package irsymcache // import "go.opentelemetry.io/ebpf-profiler/pyroscope/symb/irsymcache"

import (
	"debug/elf"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"sync"
	"syscall"
	"time"

	lru "github.com/elastic/go-freelru"
	"go.opentelemetry.io/ebpf-profiler/host"
	"go.opentelemetry.io/ebpf-profiler/reporter/samples"

	"github.com/sirupsen/logrus"

	"go.opentelemetry.io/ebpf-profiler/libpf/pfelf"
	"go.opentelemetry.io/ebpf-profiler/process"
)

// errUnknownFile is returned by ResolveAddress when the requested file has no
// usable entry in the cache (it is absent, or recorded as failed).
var errUnknownFile = errors.New("unknown file")

// cachedMarker is the value stored for each FileID in the Resolver's LRU cache.
type cachedMarker int

var (
	// cached marks a file that is considered available: a table file found on
	// disk at startup, a successful conversion, or the vDSO.
	cached cachedMarker = 1
	// erroredMarker marks a file whose conversion failed. Evicting such an
	// entry does not remove any table file, and ResolveAddress treats it as
	// unknown.
	erroredMarker cachedMarker = 2
)

// Table is an opened symbol table that a Resolver queries for addresses.
type Table interface {
	// Lookup returns the source information for addr, or an error.
	// Resolver.ResolveAddress forwards its result unchanged.
	Lookup(addr uint64) (samples.SourceInfo, error)
	// Close is called by the Resolver (Cleanup and Close) when it drops the
	// table; presumably it releases the resources held by the table.
	Close()
}

// TableFactory abstracts the on-disk table format used by a Resolver.
type TableFactory interface {
	// ConvertTable reads the executable (or its debug file) from src and
	// writes the resulting table to dst.
	ConvertTable(src *os.File, dst *os.File) error
	// OpenTable opens a previously written table file at path.
	OpenTable(path string) (Table, error)
	// Name identifies the factory; NewFSCache uses it as the name of the
	// sub-directory (below Options.Path) that holds this factory's tables.
	Name() string
}

// NewTableFactory returns the package's TableFactory implementation, an empty
// TableTableFactory value.
func NewTableFactory() TableFactory {
	var factory TableFactory = TableTableFactory{}
	return factory
}

// Resolver converts executables into table files stored under cacheDir and
// resolves addresses against those tables.
type Resolver struct {
	// logger is the base log entry; NewFSCache tags it with a component field.
	logger *logrus.Entry
	// f converts executables into table files and opens them for lookup.
	f TableFactory
	// cacheDir is Options.Path joined with the factory name; it holds one
	// table file per FileID.
	cacheDir string
	// cache maps known FileIDs to a marker. Evicting an entry that is not
	// erroredMarker removes the corresponding table file.
	cache *lru.SyncedLRU[host.FileID, cachedMarker]
	// jobs carries conversion requests to the convertLoop goroutine.
	jobs chan convertJob
	// wg tracks the convertLoop goroutine; Close waits on it.
	wg sync.WaitGroup

	// mutex is held by Cleanup, ResolveAddress and Close while they access
	// tables and shutdown.
	mutex sync.Mutex
	// tables holds the tables opened so far by ResolveAddress, keyed by FileID.
	tables map[host.FileID]Table
	// shutdown is closed by Close to stop convertLoop and then set to nil so
	// that a second Close does not close it again.
	shutdown chan struct{}
}

// Cleanup closes every table the Resolver currently holds open and empties
// the open-table map. The cache entries and table files are left untouched,
// so a later ResolveAddress reopens tables on demand.
func (c *Resolver) Cleanup() {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	open := c.tables
	for fid := range open {
		open[fid].Close()
	}
	clear(open)
}

// convertJob is a single conversion request handed to the convertLoop
// goroutine through Resolver.jobs.
type convertJob struct {
	// src is the file to convert and dst the table file to write.
	src *os.File
	dst *os.File

	// result receives the outcome of the conversion; the submitter blocks on
	// it (see convertAsync).
	result chan error
}

// Options configures NewFSCache.
type Options struct {
	// Path is the base cache directory; NewFSCache stores tables in a
	// sub-directory of it named after the TableFactory.
	Path string
	// SizeEntries is the capacity passed to the LRU cache, i.e. the number of
	// FileIDs tracked at once.
	SizeEntries uint32
}

// NewFSCache creates a Resolver that keeps its table files in the directory
// opt.Path/impl.Name(), creating that directory if necessary.
//
// Files already present in the directory whose names round-trip through
// FileIDFromStringNoQuotes / StringNoQuotes are registered in the LRU cache as
// cached; other files are ignored. When a non-errored entry is evicted from
// the cache its table file is removed from disk.
//
// A goroutine running convertLoop is started to serve conversion jobs; call
// Close on the returned Resolver to stop it.
//
// An error is returned if the directory cannot be created, the LRU cache
// cannot be constructed (for example when opt.SizeEntries is 0), or walking
// the cache directory fails.
func NewFSCache(impl TableFactory, opt Options) (*Resolver, error) {
	l := logrus.WithField("component", "irsymtab")
	l.WithFields(logrus.Fields{
		"path": opt.Path,
		"size": opt.SizeEntries,
	}).Debug()

	shutdown := make(chan struct{})
	res := &Resolver{
		logger:   l,
		f:        impl,
		cacheDir: filepath.Join(opt.Path, impl.Name()),
		jobs:     make(chan convertJob, 1),
		shutdown: shutdown,
		tables:   make(map[host.FileID]Table),
	}

	if err := os.MkdirAll(res.cacheDir, 0o700); err != nil {
		return nil, err
	}

	cache, err := lru.NewSynced[host.FileID, cachedMarker](
		opt.SizeEntries,
		func(id host.FileID) uint32 {
			return uint32(id)
		})
	// Check the constructor error before touching cache: on failure cache is
	// nil and registering the eviction callback would panic.
	if err != nil {
		return nil, err
	}
	cache.SetOnEvict(func(id host.FileID, marker cachedMarker) {
		// Errored entries never produced a table file, so there is nothing
		// to delete.
		if marker == erroredMarker {
			return
		}
		filePath := res.tableFilePath(id)
		l.WithFields(logrus.Fields{
			"file": filePath,
		}).Debug("symbcache evicting")
		// Use a callback-local error: this closure outlives NewFSCache and
		// must not write to the constructor's err variable.
		if rmErr := os.Remove(filePath); rmErr != nil {
			l.Error(rmErr)
		}
	})
	res.cache = cache

	// Register table files left over from a previous run.
	err = filepath.Walk(res.cacheDir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}
		filename := filepath.Base(path)
		id, parseErr := FileIDFromStringNoQuotes(filename)
		if parseErr != nil {
			// Not a table file name; skip it.
			return nil
		}
		// Only accept names that are exactly the canonical string form of
		// the parsed ID.
		if filename != id.StringNoQuotes() {
			return nil
		}
		res.cache.Add(id, cached)
		return nil
	})
	if err != nil {
		return nil, err
	}

	res.wg.Add(1)
	go func() {
		defer res.wg.Done()
		convertLoop(res, shutdown)
	}()

	return res, nil
}

// convertLoop is the body of the Resolver's conversion goroutine. It serves
// jobs from res.jobs one at a time, sending each outcome on the job's result
// channel, until shutdown is closed; it then processes the jobs that are
// still buffered and returns.
//
// The goroutine is locked to its OS thread for the whole loop.
// NOTE(review): the reason for the thread pinning is not visible in this
// file — confirm against the TableFactory implementation.
func convertLoop(res *Resolver, shutdown <-chan struct{}) {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	// handle runs one conversion and reports its outcome to the submitter.
	handle := func(j convertJob) {
		j.result <- res.convertSync(j.src, j.dst)
	}

	for {
		select {
		case j := <-res.jobs:
			handle(j)
		case <-shutdown:
			// Finish whatever is already queued so that its submitter is
			// not left waiting for a result.
			for len(res.jobs) > 0 {
				handle(<-res.jobs)
			}
			return
		}
	}
}

// ExecutableKnown reports whether id has an entry in the cache. Entries
// recorded as failed (erroredMarker) count as known too.
func (c *Resolver) ExecutableKnown(id host.FileID) bool {
	if _, found := c.cache.Get(id); found {
		return true
	}
	return false
}

// ObserveExecutable converts the executable behind elfRef into a table file
// for fid and records the outcome in the cache.
//
// Nothing is done (and nil is returned) when elfRef's opener is not a
// pfelf.RootFSOpener. The vDSO is marked as cached without any conversion.
// When conversion fails, fid is recorded with erroredMarker and the error is
// returned; failures caused by ESRCH, a missing file or an ELF without
// symbols are logged at debug level, all others at error level.
func (c *Resolver) ObserveExecutable(fid host.FileID, elfRef *pfelf.Reference) error {
	opener, isRootFS := elfRef.ELFOpener.(pfelf.RootFSOpener)
	if !isRootFS {
		return nil
	}
	if elfRef.FileName() == process.VdsoPathName.String() {
		c.cache.Add(fid, cached)
		return nil
	}

	// pid stays 0 when the opener is not a process.
	var pid int
	if proc, isProc := elfRef.ELFOpener.(process.Process); isProc {
		pid = int(proc.PID())
	}
	entry := c.logger.WithFields(logrus.Fields{
		"fid": fid.StringNoQuotes(),
		"elf": elfRef.FileName(),
		"pid": pid,
	})

	start := time.Now()
	err := c.convert(entry, fid, elfRef, opener)
	if err == nil {
		entry.WithField("duration", time.Since(start)).Debug("converted")
		return nil
	}

	c.cache.Add(fid, erroredMarker)
	entry = entry.WithError(err).WithField("duration", time.Since(start))
	switch {
	case errors.Is(err, syscall.ESRCH),
		errors.Is(err, os.ErrNotExist),
		errors.Is(err, elf.ErrNoSymbols):
		// Expected failures: keep them out of the error log.
		entry.Debug("conversion failed")
	default:
		entry.Error("conversion failed")
	}
	return err
}

// convert writes the table file for fid unless one already exists on disk.
//
// The conversion source is the debug-link file of the ELF when one is named
// and can be opened through o; otherwise the ELF's own OS file is used. The
// table is written by the conversion goroutine (see convertAsync). On
// success fid is added to the cache as cached; on failure the partially
// written table file is removed and the error is returned.
//
// Note that when the table file already exists the function returns nil
// without touching the cache.
func (c *Resolver) convert(
	l *logrus.Entry,
	fid host.FileID,
	elfRef *pfelf.Reference,
	o pfelf.RootFSOpener,
) error {
	tablePath := c.tableFilePath(fid)
	if info, statErr := os.Stat(tablePath); statErr == nil && info != nil {
		return nil
	}

	ef, err := elfRef.GetELF()
	if err != nil {
		return err
	}
	defer ef.Close()

	var src *os.File
	if debugLink := ef.DebuglinkFileName(elfRef.FileName(), elfRef); debugLink != "" {
		src, err = o.OpenRootFSFile(debugLink)
		if err != nil {
			// Not fatal: fall back to the ELF file itself below.
			l.WithError(err).Debug("open debug file")
		} else {
			defer src.Close()
		}
	}
	if src == nil {
		src = ef.OSFile()
	}
	if src == nil {
		return errors.New("failed to open elf os file")
	}

	dst, err := os.Create(tablePath)
	if err != nil {
		return err
	}
	defer dst.Close()

	if err = c.convertAsync(src, dst); err != nil {
		// Do not leave a partial table behind.
		_ = os.Remove(tablePath)
		return err
	}
	c.cache.Add(fid, cached)
	return nil
}

// convertAsync submits a conversion of src into dst to the conversion
// goroutine and blocks until that goroutine reports the result.
func (c *Resolver) convertAsync(src, dst *os.File) error {
	done := make(chan error)
	c.jobs <- convertJob{src: src, dst: dst, result: done}
	return <-done
}

// convertSync runs the factory's conversion of src into dst on the calling
// goroutine and returns its error.
func (c *Resolver) convertSync(src, dst *os.File) error {
	err := c.f.ConvertTable(src, dst)
	return err
}

// tableFilePath returns the location of fid's table file: the unquoted string
// form of fid inside the cache directory.
func (c *Resolver) tableFilePath(fid host.FileID) string {
	name := fid.StringNoQuotes()
	return filepath.Join(c.cacheDir, name)
}

// ResolveAddress looks up addr in the table belonging to fid.
//
// errUnknownFile is returned when fid is not in the cache or is recorded as
// failed. The table is opened on first use and kept open for later calls. If
// it cannot be opened, its file and cache entry are removed and the open
// error is returned.
func (c *Resolver) ResolveAddress(
	fid host.FileID,
	addr uint64,
) (samples.SourceInfo, error) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	marker, found := c.cache.Get(fid)
	if !found || marker == erroredMarker {
		return samples.SourceInfo{}, errUnknownFile
	}

	table, isOpen := c.tables[fid]
	if !isOpen {
		tablePath := c.tableFilePath(fid)
		opened, err := c.f.OpenTable(tablePath)
		if err != nil {
			// Drop the unusable table so the file can be converted again.
			_ = os.Remove(tablePath)
			c.cache.Remove(fid)
			return samples.SourceInfo{}, err
		}
		c.tables[fid] = opened
		table = opened
	}
	return table.Lookup(addr)
}

// Close stops the conversion goroutine, waits for it to finish and then
// closes all open tables. It may be called more than once; the shutdown
// channel is closed only on the first call. The returned error is always nil.
func (c *Resolver) Close() error {
	c.mutex.Lock()
	if stop := c.shutdown; stop != nil {
		c.shutdown = nil
		close(stop)
	}
	c.mutex.Unlock()

	// The mutex is released while waiting for the worker to drain and exit.
	c.wg.Wait()

	// Cleanup takes the mutex, closes every open table and clears the map.
	c.Cleanup()
	return nil
}

// FileIDFromStringNoQuotes parses a 32-character FileID string into a
// host.FileID.
//
// Only the first 16 characters are decoded, as a hexadecimal uint64; the
// remaining 16 characters are not inspected. Callers that need strict
// validation can compare s with the StringNoQuotes form of the returned ID.
//
// An error is returned when s is not exactly 32 bytes long or when its first
// 16 characters are not valid hexadecimal. In the latter case the underlying
// strconv error is wrapped and can be inspected with errors.As / errors.Is.
func FileIDFromStringNoQuotes(s string) (host.FileID, error) {
	if len(s) != 32 {
		return host.FileID(0), fmt.Errorf("invalid length for FileID string '%s': %d (expected 32)", s, len(s))
	}

	// Parse the first 16 hex characters as uint64.
	val, err := strconv.ParseUint(s[:16], 16, 64)
	if err != nil {
		return host.FileID(0), fmt.Errorf("failed to parse FileID string '%s': %w", s, err)
	}

	return host.FileID(val), nil
}
