/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.neuralsearch.sparse.accessor.SparseVectorForwardIndex;
import org.opensearch.neuralsearch.sparse.accessor.SparseVectorReader;
import org.opensearch.neuralsearch.sparse.cache.CacheGatedForwardIndexReader;
import org.opensearch.neuralsearch.sparse.cache.CacheKey;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCache;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCacheItem;
import org.opensearch.neuralsearch.sparse.codec.SparseBinaryDocValuesPassThrough;
import org.opensearch.neuralsearch.sparse.common.PredicateUtils;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizationUtil;
import org.opensearch.neuralsearch.sparse.query.ExactMatchScorer;
import org.opensearch.neuralsearch.sparse.query.OrderedPostingWithClustersScorer;
import org.opensearch.neuralsearch.sparse.query.SparseVectorQuery;

public class SparseQueryWeight
extends Weight {
    @Generated
    private static final Logger log = LogManager.getLogger(SparseQueryWeight.class);
    private final float boost;
    private final Weight fallbackQueryWeight;
    private final ForwardIndexCache forwardIndexCache;

    public SparseQueryWeight(SparseVectorQuery query, IndexSearcher searcher, ScoreMode scoreMode, float boost, ForwardIndexCache forwardIndexCache) throws IOException {
        super((Query)query);
        this.boost = boost;
        this.forwardIndexCache = forwardIndexCache;
        this.fallbackQueryWeight = query.getFallbackQuery().createWeight(searcher, scoreMode, boost);
    }

    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        return null;
    }

    public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
        SparseVectorQuery query = (SparseVectorQuery)this.parentQuery;
        SegmentInfo info = Lucene.segmentReader((LeafReader)context.reader()).getSegmentInfo().info;
        FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo(query.getFieldName());
        if (!PredicateUtils.shouldRunSeisPredicate.test(info, fieldInfo)) {
            return this.fallbackQueryWeight.scorerSupplier(context);
        }
        final Scorer scorer = this.selectScorer(query, context, info);
        return new ScorerSupplier(this){

            public Scorer get(long leadCost) throws IOException {
                return scorer;
            }

            public BulkScorer bulkScorer() throws IOException {
                return new BulkScorer(){

                    public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
                        collector.setScorer((Scorable)scorer);
                        DocIdSetIterator iter = scorer.iterator();
                        int docId = iter.nextDoc();
                        while (docId != Integer.MAX_VALUE) {
                            collector.collect(docId);
                            docId = iter.nextDoc();
                        }
                        return Integer.MAX_VALUE;
                    }

                    public long cost() {
                        return 0L;
                    }
                };
            }

            public long cost() {
                return 0L;
            }
        };
    }

    @VisibleForTesting
    Scorer selectScorer(SparseVectorQuery query, LeafReaderContext context, SegmentInfo segmentInfo) throws IOException {
        BitSet filter;
        SparseVectorReader cacheGatedForwardIndexReader = SparseVectorReader.NOOP_READER;
        FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo(query.getFieldName());
        float rescaledBoost = this.boost * ByteQuantizationUtil.getCeilingValueIngest(fieldInfo) * ByteQuantizationUtil.getCeilingValueSearch(fieldInfo) / 255.0f / 255.0f;
        if (segmentInfo != null) {
            CacheKey key = new CacheKey(segmentInfo, query.getFieldName());
            ForwardIndexCacheItem cacheItem = this.forwardIndexCache.getOrCreate(key, segmentInfo.maxDoc());
            cacheGatedForwardIndexReader = this.getCacheGatedForwardIndexReader(cacheItem, context.reader(), query.getFieldName());
        }
        Similarity.SimScorer simScorer = ByteQuantizationUtil.getSimScorer(rescaledBoost);
        BitSetIterator filterBitIterator = null;
        if (query.getFilterResults() != null && (filter = query.getFilterResults().get(context.id())) != null) {
            int ord = filter.cardinality();
            filterBitIterator = new BitSetIterator(filter, (long)ord);
            if (ord <= query.getQueryContext().getK()) {
                return new ExactMatchScorer(filterBitIterator, query.getQueryVector(), cacheGatedForwardIndexReader, simScorer);
            }
        }
        return new OrderedPostingWithClustersScorer(query.getFieldName(), query.getQueryContext(), query.getQueryVector(), context.reader(), context.reader().getLiveDocs(), cacheGatedForwardIndexReader, simScorer, filterBitIterator);
    }

    private SparseVectorReader getCacheGatedForwardIndexReader(SparseVectorForwardIndex index, LeafReader leafReader, String fieldName) throws IOException {
        BinaryDocValues docValues = leafReader.getBinaryDocValues(fieldName);
        if (docValues instanceof SparseBinaryDocValuesPassThrough) {
            SparseBinaryDocValuesPassThrough sparseBinaryDocValuesPassThrough = (SparseBinaryDocValuesPassThrough)docValues;
            return new CacheGatedForwardIndexReader(index.getReader(), index.getWriter(), sparseBinaryDocValuesPassThrough);
        }
        return SparseVectorReader.NOOP_READER;
    }

    public boolean isCacheable(LeafReaderContext ctx) {
        return false;
    }
}

