package com.tencent.supersonic.headless.chat.mapper;

import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Match strategy that maps query text segments to {@link EmbeddingResult}s by querying the
 * {@link MetaEmbeddingService}.
 *
 * <p>Matching is done in batches by {@link #detectByBatch}; {@link #detectByStep} has an empty
 * body in this strategy. All tunables (length bounds, batch size, thresholds, result counts) are
 * read from the inherited {@code mapperConfig}.
 */
@Service
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingMatchStrategy.class);

    @Autowired
    private MetaEmbeddingService metaEmbeddingService;

    /**
     * Tells whether two results share the same map key while {@code other} has a strictly larger
     * distance than {@code result}.
     *
     * <p>NOTE(review): presumably this means {@code other} should be dropped in favour of
     * {@code result}; confirm against the caller in {@code BaseMatchStrategy}.
     *
     * @param result the result whose distance is used as the reference
     * @param other the result compared against {@code result}
     * @return {@code true} if both have the same {@link #getMapKey} and {@code other} is farther
     */
    @Override
    public boolean needDelete(EmbeddingResult result, EmbeddingResult other) {
        return getMapKey(result).equals(getMapKey(other))
                && other.getDistance() > result.getDistance();
    }

    /**
     * Builds the key used to identify a result: its name and id joined by an underscore.
     *
     * @param embeddingResult the result to build the key for
     * @return {@code name + "_" + id}
     */
    @Override
    public String getMapKey(EmbeddingResult embeddingResult) {
        return embeddingResult.getName() + "_" + embeddingResult.getId();
    }

    /**
     * No-op: this strategy performs no per-step detection; see {@link #detectByBatch}.
     */
    @Override
    public void detectByStep(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
            Set<Long> dataSetIds, String detectSegment, int offset) {
    }

    /**
     * Trims the given text segments, keeps the non-blank ones whose length lies within the
     * configured {@code EMBEDDING_MAPPER_MIN}..{@code EMBEDDING_MAPPER_MAX} range (inclusive),
     * and runs the embedding lookup on them in chunks of {@code EMBEDDING_MAPPER_BATCH}.
     *
     * @param chatQueryContext context of the current query
     * @param results set handed to {@code selectResultInOneRound} for each chunk
     * @param dataSetIds ids forwarded unchanged to the embedding service
     * @param detectSegments raw text segments to look up
     */
    @Override
    protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
            Set<Long> dataSetIds, Set<String> detectSegments) {
        int minLength =
                Integer.parseInt(this.mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
        int maxLength =
                Integer.parseInt(this.mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
        int batchSize =
                Integer.parseInt(this.mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));

        List<String> queryTexts = detectSegments.stream()
                .map(String::trim)
                .filter(text -> StringUtils.isNotBlank(text)
                        && text.length() >= minLength
                        && text.length() <= maxLength)
                .collect(Collectors.toList());

        for (List<String> queryTextsSub : Lists.partition(queryTexts, batchSize)) {
            detectByQueryTextsSub(results, dataSetIds, queryTextsSub, chatQueryContext);
        }
    }

    /**
     * Queries the embedding service for one chunk of texts, converts the surviving retrievals to
     * {@link EmbeddingResult}s and passes the closest ones to {@code selectResultInOneRound}.
     *
     * <p>A retrieval is discarded only when the query text does not contain the retrieved text
     * and its distance exceeds {@code 1 - threshold}. At most
     * {@code EMBEDDING_MAPPER_ROUND_NUMBER * queryTextsSub.size()} results, ordered by ascending
     * distance, are kept.
     *
     * @param results set handed to {@code selectResultInOneRound}
     * @param dataSetIds ids forwarded unchanged to the embedding service
     * @param queryTextsSub the chunk of query texts to look up
     * @param chatQueryContext context supplying the model-to-dataset mapping and the map mode
     */
    private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> dataSetIds,
            List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
        Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
        double threshold = getThreshold(
                Double.valueOf(this.mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_THRESHOLD)),
                Double.valueOf(this.mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN)),
                chatQueryContext.getMapModeEnum());

        int retrieveNumber =
                Integer.parseInt(this.mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_NUMBER));
        List<RetrieveQueryResult> queryResults = this.metaEmbeddingService.retrieveQuery(
                RetrieveQuery.builder().queryTextsList(queryTextsSub).build(),
                retrieveNumber, modelIdToDataSetIds, dataSetIds);
        if (CollectionUtils.isEmpty(queryResults)) {
            return;
        }

        double maxDistance = 1.0d - threshold;
        int roundNumber =
                Integer.parseInt(this.mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER));
        // Widen before multiplying so a large configuration value cannot overflow int.
        long maxResults = (long) roundNumber * queryTextsSub.size();

        List<EmbeddingResult> oneRoundResults = queryResults.stream()
                .filter(queryResult -> CollectionUtils.isNotEmpty(queryResult.getRetrieval()))
                .flatMap(queryResult -> queryResult.getRetrieval().stream()
                        // Filter instead of removeIf: the lists belong to the service response
                        // and must not be modified from inside the pipeline.
                        .filter(retrieval -> queryResult.getQuery().contains(retrieval.getQuery())
                                || !(retrieval.getDistance() > maxDistance))
                        .map(retrieval -> {
                            EmbeddingResult embeddingResult = new EmbeddingResult();
                            BeanUtils.copyProperties(retrieval, embeddingResult);
                            embeddingResult.setDetectWord(queryResult.getQuery());
                            embeddingResult.setName(retrieval.getQuery());
                            if (retrieval.getMetadata() != null) {
                                // Collectors.toMap rejects null values, so skip them.
                                Map<String, String> metadata = retrieval.getMetadata().entrySet().stream()
                                        .filter(entry -> entry.getValue() != null)
                                        .collect(Collectors.toMap(
                                                entry -> entry.getKey(),
                                                entry -> entry.getValue().toString()));
                                embeddingResult.setMetadata(metadata);
                            }
                            return embeddingResult;
                        }))
                .sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
                .limit(maxResults)
                .collect(Collectors.toList());

        selectResultInOneRound(results, oneRoundResults);
    }
}
