/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.mmr;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.search.processor.mmr.MMRQueryTransformer;
import org.opensearch.knn.search.processor.mmr.MMRRerankContext;
import org.opensearch.knn.search.processor.mmr.MMRTransformContext;
import org.opensearch.knn.search.processor.mmr.MMRUtil;
import org.opensearch.knn.search.processor.mmr.MMRVectorFieldInfo;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
import org.opensearch.neuralsearch.util.SemanticMappingUtils;
import org.opensearch.transport.client.Client;

public class MMRNeuralQueryTransformer
implements MMRQueryTransformer<NeuralQueryBuilder> {
    public void transform(NeuralQueryBuilder queryBuilder, ActionListener<Void> listener, MMRTransformContext mmrTransformContext) {
        try {
            EventStatsManager.increment(EventStatName.MMR_NEURAL_QUERY_TRANSFORMER);
            if (queryBuilder.maxDistance() == null && queryBuilder.minScore() == null) {
                queryBuilder.k(mmrTransformContext.getCandidates());
            }
            if (mmrTransformContext.isVectorFieldInfoResolved()) {
                listener.onResponse(null);
                return;
            }
            List remoteIndices = mmrTransformContext.getRemoteIndices();
            if (!remoteIndices.isEmpty()) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] in the mmr search extension should be provided for remote indices [%s].", "vector_field_path", String.join((CharSequence)",", remoteIndices)));
            }
            MMRRerankContext mmrRerankContext = mmrTransformContext.getMmrRerankContext();
            String queryFieldName = queryBuilder.fieldName();
            if (queryFieldName == null) {
                throw new IllegalArgumentException("Failed to transform the neural query for MMR. Query field name should not be null.");
            }
            List<MMRVectorFieldInfo> vectorFieldInfos = this.collectVectorFieldInfos(queryFieldName, mmrTransformContext.getLocalIndexMetadataList());
            HashMap<String, String> indexToFieldPathMap = new HashMap<String, String>();
            HashSet<String> uniqueFieldPaths = new HashSet<String>();
            for (MMRVectorFieldInfo info : vectorFieldInfos) {
                indexToFieldPathMap.put(info.getIndexName(), info.getFieldPath());
                uniqueFieldPaths.add(info.getFieldPath());
            }
            if (uniqueFieldPaths.size() == 1) {
                mmrRerankContext.setVectorFieldPath((String)uniqueFieldPaths.iterator().next());
            } else {
                mmrRerankContext.setIndexToVectorFieldPathMap(indexToFieldPathMap);
            }
            MMRUtil.resolveKnnVectorFieldInfo(vectorFieldInfos, (SpaceType)mmrTransformContext.getUserProvidedSpaceType(), (VectorDataType)mmrTransformContext.getUserProvidedVectorDataType(), (Client)mmrTransformContext.getClient(), (ActionListener)ActionListener.wrap(vectorFieldInfo -> {
                mmrRerankContext.setVectorDataType(vectorFieldInfo.getVectorDataType());
                mmrRerankContext.setSpaceType(vectorFieldInfo.getSpaceType());
                listener.onResponse(null);
            }, arg_0 -> listener.onFailure(arg_0)));
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
    }

    public String getQueryName() {
        return "neural";
    }

    private List<MMRVectorFieldInfo> collectVectorFieldInfos(String queryFieldPath, List<IndexMetadata> indexMetadataList) {
        ArrayList<MMRVectorFieldInfo> vectorFieldInfos = new ArrayList<MMRVectorFieldInfo>();
        for (IndexMetadata indexMetadata : indexMetadataList) {
            vectorFieldInfos.add(this.collectKnnVectorFieldInfo(indexMetadata, queryFieldPath));
        }
        return vectorFieldInfos;
    }

    private MMRVectorFieldInfo collectKnnVectorFieldInfo(IndexMetadata indexMetadata, String queryFieldPath) {
        MMRVectorFieldInfo vectorFieldInfo = new MMRVectorFieldInfo();
        vectorFieldInfo.setIndexNameByIndexMetadata(indexMetadata);
        MappingMetadata mappingMetadata = indexMetadata.mapping();
        if (mappingMetadata == null) {
            vectorFieldInfo.setUnmapped(true);
            return vectorFieldInfo;
        }
        Map mapping = mappingMetadata.sourceAsMap();
        Map queryFieldConfig = MMRUtil.getMMRFieldMappingByPath((Map)mapping, (String)queryFieldPath);
        if (queryFieldConfig == null) {
            vectorFieldInfo.setUnmapped(true);
            return vectorFieldInfo;
        }
        vectorFieldInfo.setUnmapped(false);
        vectorFieldInfo.setFieldPath(queryFieldPath);
        String fieldType = (String)queryFieldConfig.get("type");
        vectorFieldInfo.setFieldType(fieldType);
        Map knnVectorFieldConfig = queryFieldConfig;
        if ("semantic".equals(fieldType)) {
            Object chunkingConfig = queryFieldConfig.get("chunking");
            if (chunkingConfig != null && !Boolean.FALSE.equals(chunkingConfig)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "Field [%s] is a semantic field with chunking enabled, which can produce multiple vectors per document. MMR reranking does not support multiple vectors per document.", queryFieldPath));
            }
            String semanticInfoFieldPath = SemanticMappingUtils.getSemanticInfoFieldFullPath(queryFieldConfig, queryFieldPath, queryFieldPath);
            String vectorFieldPath = semanticInfoFieldPath + ".embedding";
            knnVectorFieldConfig = MMRUtil.getMMRFieldMappingByPath((Map)mapping, (String)vectorFieldPath);
            if (knnVectorFieldConfig == null) {
                throw new IllegalStateException(String.format(Locale.ROOT, "Failed to find the vector field [%s] from index mapping for the semantic field [%s] when transform the neural query for MMR.", vectorFieldPath, queryFieldPath));
            }
            String vectorFieldType = (String)knnVectorFieldConfig.get("type");
            if (!"knn_vector".equals(vectorFieldType)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "Field [%s] is a semantic field with a non-KNN embedding [%s]. MMR reranking only can support knn_vector field.", queryFieldPath, vectorFieldType));
            }
            vectorFieldInfo.setFieldType(vectorFieldType);
            vectorFieldInfo.setFieldPath(vectorFieldPath);
        } else if (!"knn_vector".equals(fieldType)) {
            return vectorFieldInfo;
        }
        vectorFieldInfo.setKnnConfig(knnVectorFieldConfig);
        return vectorFieldInfo;
    }

    @Generated
    public MMRNeuralQueryTransformer() {
    }
}

