/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

public final class HnswGraphBuilder<T> {
    private static final long DEFAULT_RAND_SEED = 42L;
    public static final String HNSW_COMPONENT = "HNSW";
    public static long randSeed = 42L;
    private final int M;
    private final int beamWidth;
    private final double ml;
    private final NeighborArray scratch;
    private final VectorSimilarityFunction similarityFunction;
    private final VectorEncoding vectorEncoding;
    private final RandomAccessVectorValues<T> vectors;
    private final SplittableRandom random;
    private final HnswGraphSearcher<T> graphSearcher;
    final OnHeapHnswGraph hnsw;
    private InfoStream infoStream = InfoStream.getDefault();
    private final RandomAccessVectorValues<T> vectorsCopy;

    public static <T> HnswGraphBuilder<T> create(RandomAccessVectorValues<T> vectors, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction, int M, int beamWidth, long seed) throws IOException {
        return new HnswGraphBuilder<T>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
    }

    private HnswGraphBuilder(RandomAccessVectorValues<T> vectors, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction, int M, int beamWidth, long seed) throws IOException {
        this.vectors = vectors;
        this.vectorsCopy = vectors.copy();
        this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        if (M <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.M = M;
        this.beamWidth = beamWidth;
        this.ml = M == 1 ? 1.0 : 1.0 / Math.log(1.0 * (double)M);
        this.random = new SplittableRandom(seed);
        int levelOfFirstNode = HnswGraphBuilder.getRandomGraphLevel(this.ml, this.random);
        this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
        this.graphSearcher = new HnswGraphSearcher(vectorEncoding, similarityFunction, new NeighborQueue(beamWidth, true), new FixedBitSet(this.vectors.size()));
        this.scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
    }

    public OnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
        if (vectorsToAdd == this.vectors) {
            throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
        }
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            this.infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors");
        }
        this.addVectors(vectorsToAdd);
        return this.hnsw;
    }

    private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
        long start;
        long t = start = System.nanoTime();
        for (int node = 1; node < vectorsToAdd.size(); ++node) {
            this.addGraphNode(node, vectorsToAdd);
            if (node % 10000 != 0 || !this.infoStream.isEnabled(HNSW_COMPONENT)) continue;
            t = this.printGraphBuildStatus(node, start, t);
        }
    }

    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    public OnHeapHnswGraph getGraph() {
        return this.hnsw;
    }

    public void addGraphNode(int node, T value) throws IOException {
        NeighborQueue candidates;
        int level;
        int nodeLevel = HnswGraphBuilder.getRandomGraphLevel(this.ml, this.random);
        int curMaxLevel = this.hnsw.numLevels() - 1;
        int[] eps = new int[]{this.hnsw.entryNode()};
        for (level = nodeLevel; level > curMaxLevel; --level) {
            this.hnsw.addNode(level, node);
        }
        for (level = curMaxLevel; level > nodeLevel; --level) {
            candidates = this.graphSearcher.searchLevel(value, 1, level, eps, this.vectors, this.hnsw);
            eps = new int[]{candidates.pop()};
        }
        for (level = Math.min(nodeLevel, curMaxLevel); level >= 0; --level) {
            candidates = this.graphSearcher.searchLevel(value, this.beamWidth, level, eps, this.vectors, this.hnsw);
            eps = candidates.nodes();
            this.hnsw.addNode(level, node);
            this.addDiverseNeighbors(level, node, candidates);
        }
    }

    public void addGraphNode(int node, RandomAccessVectorValues<T> values2) throws IOException {
        this.addGraphNode(node, values2.vectorValue(node));
    }

    private long printGraphBuildStatus(int node, long start, long t) {
        long now = System.nanoTime();
        this.infoStream.message(HNSW_COMPONENT, String.format(Locale.ROOT, "built %d in %d/%d ms", node, TimeUnit.NANOSECONDS.toMillis(now - t), TimeUnit.NANOSECONDS.toMillis(now - start)));
        return now;
    }

    private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) throws IOException {
        NeighborArray neighbors = this.hnsw.getNeighbors(level, node);
        assert (neighbors.size() == 0);
        this.popToScratch(candidates);
        int maxConnOnLevel = level == 0 ? this.M * 2 : this.M;
        this.selectAndLinkDiverse(neighbors, this.scratch, maxConnOnLevel);
        int size = neighbors.size();
        for (int i = 0; i < size; ++i) {
            int nbr = neighbors.node[i];
            NeighborArray nbrNbr = this.hnsw.getNeighbors(level, nbr);
            nbrNbr.insertSorted(node, neighbors.score[i]);
            if (nbrNbr.size() <= maxConnOnLevel) continue;
            int indexToRemove = this.findWorstNonDiverse(nbrNbr);
            nbrNbr.removeIndex(indexToRemove);
        }
    }

    private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
        for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; --i) {
            int cNode = candidates.node[i];
            float cScore = candidates.score[i];
            assert (cNode < this.hnsw.size());
            if (!this.diversityCheck(cNode, cScore, neighbors)) continue;
            neighbors.add(cNode, cScore);
        }
    }

    private void popToScratch(NeighborQueue candidates) {
        this.scratch.clear();
        int candidateCount = candidates.size();
        for (int i = 0; i < candidateCount; ++i) {
            float maxSimilarity = candidates.topScore();
            this.scratch.add(candidates.pop(), maxSimilarity);
        }
    }

    private boolean diversityCheck(int candidate, float score, NeighborArray neighbors) throws IOException {
        return this.isDiverse(candidate, neighbors, score);
    }

    private boolean isDiverse(int candidate, NeighborArray neighbors, float score) throws IOException {
        switch (this.vectorEncoding) {
            case BYTE: {
                return this.isDiverse((byte[])this.vectors.vectorValue(candidate), neighbors, score);
            }
        }
        return this.isDiverse((float[])this.vectors.vectorValue(candidate), neighbors, score);
    }

    private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score) throws IOException {
        for (int i = 0; i < neighbors.size(); ++i) {
            float neighborSimilarity = this.similarityFunction.compare(candidate, (float[])this.vectorsCopy.vectorValue(neighbors.node[i]));
            if (!(neighborSimilarity >= score)) continue;
            return false;
        }
        return true;
    }

    private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score) throws IOException {
        for (int i = 0; i < neighbors.size(); ++i) {
            float neighborSimilarity = this.similarityFunction.compare(candidate, (byte[])this.vectorsCopy.vectorValue(neighbors.node[i]));
            if (!(neighborSimilarity >= score)) continue;
            return false;
        }
        return true;
    }

    private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
        for (int i = neighbors.size() - 1; i > 0; --i) {
            if (!this.isWorstNonDiverse(i, neighbors)) continue;
            return i;
        }
        return neighbors.size() - 1;
    }

    private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors) throws IOException {
        int candidateNode = neighbors.node[candidateIndex];
        switch (this.vectorEncoding) {
            case BYTE: {
                return this.isWorstNonDiverse(candidateIndex, (byte[])this.vectors.vectorValue(candidateNode), neighbors);
            }
        }
        return this.isWorstNonDiverse(candidateIndex, (float[])this.vectors.vectorValue(candidateNode), neighbors);
    }

    private boolean isWorstNonDiverse(int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
        float minAcceptedSimilarity = neighbors.score[candidateIndex];
        for (int i = candidateIndex - 1; i >= 0; --i) {
            float neighborSimilarity = this.similarityFunction.compare(candidateVector, (float[])this.vectorsCopy.vectorValue(neighbors.node[i]));
            if (!(neighborSimilarity >= minAcceptedSimilarity)) continue;
            return true;
        }
        return false;
    }

    private boolean isWorstNonDiverse(int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException {
        float minAcceptedSimilarity = neighbors.score[candidateIndex];
        for (int i = candidateIndex - 1; i >= 0; --i) {
            float neighborSimilarity = this.similarityFunction.compare(candidateVector, (byte[])this.vectorsCopy.vectorValue(neighbors.node[i]));
            if (!(neighborSimilarity >= minAcceptedSimilarity)) continue;
            return true;
        }
        return false;
    }

    private static int getRandomGraphLevel(double ml, SplittableRandom random) {
        double randDouble;
        while ((randDouble = random.nextDouble()) == 0.0) {
        }
        return (int)(-Math.log(randDouble) * ml);
    }
}

