/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.headless.chat.mapper;

import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
import com.tencent.supersonic.headless.chat.mapper.BaseMatchStrategy;
import com.tencent.supersonic.headless.chat.mapper.MapperConfig;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
/**
 * Match strategy that detects {@link EmbeddingResult} candidates for query-text segments by
 * retrieving similar items through {@link MetaEmbeddingService}.
 *
 * <p>Segments are filtered by length, split into batches, and each batch is sent as a single
 * {@link RetrieveQuery}. Retrieved items are filtered by a distance threshold, sorted by distance,
 * and the closest ones are handed to {@code selectResultInOneRound}.
 */
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {

    private static final Logger log = LoggerFactory.getLogger(EmbeddingMatchStrategy.class);

    @Autowired
    private MetaEmbeddingService metaEmbeddingService;

    /**
     * Returns {@code true} when both results share the same map key and {@code existResult} has a
     * strictly larger distance than {@code oneRoundResult}. Presumably this lets the base strategy
     * drop the farther of two duplicates.
     *
     * @param oneRoundResult a result produced in the current round
     * @param existResult    a previously collected result
     * @return whether {@code existResult} is a same-key result with a larger distance
     */
    @Override
    public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
        return getMapKey(oneRoundResult).equals(getMapKey(existResult))
                && existResult.getDistance() > oneRoundResult.getDistance();
    }

    /**
     * Builds the de-duplication key of a result as {@code name + "_" + id}.
     *
     * @param a the result
     * @return the map key
     */
    @Override
    public String getMapKey(EmbeddingResult a) {
        return a.getName() + "_" + a.getId();
    }

    /**
     * Per-step detection is not used by this strategy; all matching is done in
     * {@link #detectByBatch}. Intentionally a no-op.
     */
    @Override
    public void detectByStep(ChatQueryContext chatQueryContext, Set<EmbeddingResult> existResults,
            Set<Long> detectDataSetIds, String detectSegment, int offset) {
        // intentionally empty
    }

    /**
     * Trims the detect segments and keeps the non-blank ones whose length lies within
     * [EMBEDDING_MAPPER_MIN, EMBEDDING_MAPPER_MAX]. It then processes them in batches of
     * EMBEDDING_MAPPER_BATCH.
     *
     * @param chatQueryContext current query context
     * @param results          accumulator that receives the selected results
     * @param detectDataSetIds data-set ids to restrict retrieval to
     * @param detectSegments   candidate text segments
     */
    @Override
    protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
            Set<Long> detectDataSetIds, Set<String> detectSegments) {
        int embeddingMapperMin =
                Integer.parseInt(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
        int embeddingMapperMax =
                Integer.parseInt(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
        int embeddingMapperBatch =
                Integer.parseInt(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));

        List<String> queryTextsList = detectSegments.stream()
                .map(String::trim)
                .filter(segment -> StringUtils.isNotBlank(segment)
                        && segment.length() >= embeddingMapperMin
                        && segment.length() <= embeddingMapperMax)
                .collect(Collectors.toList());

        // Lists.partition rejects a non-positive batch size (IllegalArgumentException).
        List<List<String>> queryTextsSubList = Lists.partition(queryTextsList, embeddingMapperBatch);
        for (List<String> queryTextsSub : queryTextsSubList) {
            detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
        }
    }

    /**
     * Retrieves embedding matches for one batch of query texts and selects the closest ones.
     *
     * <p>A retrieved item is kept when its text is contained in the query text. Otherwise it is
     * kept only when its distance is not greater than {@code 1 - threshold}. The survivors are
     * sorted by ascending distance and limited to {@code EMBEDDING_MAPPER_ROUND_NUMBER *
     * batchSize}. The lists returned by the embedding service are not modified.
     *
     * @param results          accumulator that receives the selected results
     * @param detectDataSetIds data-set ids to restrict retrieval to
     * @param queryTextsSub    one batch of query texts
     * @param chatQueryContext current query context
     */
    private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
            List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
        Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
        double embeddingThreshold = Double.parseDouble(
                mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_THRESHOLD));
        double embeddingThresholdMin = Double.parseDouble(
                mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN));
        double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin,
                chatQueryContext.getMapModeEnum());
        int embeddingNumber =
                Integer.parseInt(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_NUMBER));

        RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
        List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
                retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
        if (CollectionUtils.isEmpty(retrieveQueryResults)) {
            return;
        }

        double maxDistance = 1.0 - threshold;
        List<EmbeddingResult> candidates = retrieveQueryResults.stream()
                .filter(queryResult -> CollectionUtils.isNotEmpty(queryResult.getRetrieval()))
                .flatMap(queryResult -> queryResult.getRetrieval().stream()
                        // Non-mutating filter: texts contained in the query always pass; others must
                        // not exceed maxDistance. Written as !(d > max) to match the original rule
                        // exactly, including for NaN distances.
                        .filter(retrieval -> queryResult.getQuery().contains(retrieval.getQuery())
                                || !(retrieval.getDistance() > maxDistance))
                        .map(retrieval -> {
                            EmbeddingResult embeddingResult = new EmbeddingResult();
                            BeanUtils.copyProperties(retrieval, embeddingResult);
                            embeddingResult.setDetectWord(queryResult.getQuery());
                            embeddingResult.setName(retrieval.getQuery());
                            embeddingResult.setMetadata(toStringMetadata(retrieval.getMetadata()));
                            return embeddingResult;
                        }))
                .collect(Collectors.toList());

        int embeddingRoundNumber = Integer.parseInt(
                mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER));
        int roundNumber = embeddingRoundNumber * queryTextsSub.size();
        List<EmbeddingResult> oneRoundResults = candidates.stream()
                .sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
                .limit(roundNumber)
                .collect(Collectors.toList());
        selectResultInOneRound(results, oneRoundResults);
    }

    /**
     * Converts retrieval metadata into a mutable {@code Map<String, String>} using
     * {@code toString()} on each value.
     *
     * <p>A {@code null} map yields an empty map, and entries with {@code null} values are skipped.
     * {@link Collectors#toMap} cannot hold null values, and calling {@code toString()} on one would
     * throw.
     *
     * @param metadata the raw metadata, may be {@code null}
     * @return a new map of string values, never {@code null}
     */
    private static Map<String, String> toStringMetadata(Map<String, ?> metadata) {
        if (metadata == null) {
            return new HashMap<>();
        }
        return metadata.entrySet().stream()
                .filter(entry -> entry.getValue() != null)
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString()));
    }
}

