package com.tencent.supersonic.chat.server.plugin.recognize.embedding;

import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.ParseMode;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.class */
/**
 * Plugin recognizer that recalls candidate plugins through the embedding store
 * ({@link PluginManager#recognize(String)}).
 *
 * <p>It returns the first recalled plugin, in ascending order of absolute embedding distance,
 * that meets two conditions: {@link PluginManager#resolve} accepts it, and resolve yields a
 * non-empty set of data-set ids.
 */
public class EmbeddingRecallRecognizer extends PluginRecognizer {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingRecallRecognizer.class);

    /**
     * Runs this recognizer only when at least one plugin is available for the context.
     *
     * @param chatParseContext current parse context
     * @return {@code true} if {@code getPluginList(chatParseContext)} is non-empty
     */
    @Override
    public boolean checkPreCondition(ChatParseContext chatParseContext) {
        return !CollectionUtils.isEmpty(getPluginList(chatParseContext));
    }

    /**
     * Recalls the best plugin for the query text via embeddings.
     *
     * <p>Retrievals are visited in the order produced by {@link #embeddingRecall(String)},
     * i.e. ascending absolute distance. The first one that satisfies all of the following
     * wins:
     * <ul>
     *   <li>its id parses as a plugin id;</li>
     *   <li>that id is present in the context's plugin list;</li>
     *   <li>{@link PluginManager#resolve} reports success with a non-empty data-set id set.</li>
     * </ul>
     * The chosen plugin's parse mode is set to {@link ParseMode#EMBEDDING_RECALL}. Its score is
     * {@code queryText.length() * (1 - distance)}.
     *
     * @param chatParseContext current parse context
     * @return the recall result, or {@code null} if nothing matched
     */
    @Override
    public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
        String queryText = chatParseContext.getQueryText();
        List<Retrieval> retrievals = embeddingRecall(queryText);
        if (CollectionUtils.isEmpty(retrievals)) {
            return null;
        }
        List<ChatPlugin> plugins = getPluginList(chatParseContext);
        if (CollectionUtils.isEmpty(plugins)) {
            return null;
        }
        // Keep the first plugin on duplicate ids instead of failing the whole recall.
        Map<Long, ChatPlugin> pluginsById = plugins.stream()
                .collect(Collectors.toMap(ChatPlugin::getId, plugin -> plugin, (first, second) -> first));

        for (Retrieval retrieval : retrievals) {
            Long pluginId = parsePluginId(retrieval.getId());
            ChatPlugin plugin = pluginId == null ? null : pluginsById.get(pluginId);
            if (plugin == null) {
                continue;
            }
            Pair<Boolean, Set<Long>> resolve = PluginManager.resolve(plugin, chatParseContext);
            log.info("embedding plugin resolve: {}", resolve);
            // Null-safe: a missing result or flag counts as "not resolved".
            if (resolve == null || !Boolean.TRUE.equals(resolve.getLeft())) {
                continue;
            }
            Set<Long> dataSetIds = resolve.getRight();
            if (CollectionUtils.isEmpty(dataSetIds)) {
                continue;
            }
            plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
            double distance = retrieval.getDistance();
            return PluginRecallResult.builder()
                    .plugin(plugin)
                    .dataSetIds(dataSetIds)
                    .score(queryText.length() * (1.0d - distance))
                    .distance(distance)
                    .build();
        }
        return null;
    }

    /**
     * Queries the embedding store for plugins matching {@code str}. The retrievals are sorted
     * in place by ascending absolute distance.
     *
     * <p>Any failure is logged at WARN level and yields an empty list, so recognition degrades
     * gracefully rather than failing the parse.
     *
     * @param str query text
     * @return sorted retrievals; never {@code null}
     */
    public List<Retrieval> embeddingRecall(String str) {
        try {
            PluginManager pluginManager = (PluginManager) ContextUtils.getBean(PluginManager.class);
            RetrieveQueryResult recognize = pluginManager.recognize(str);
            if (recognize == null || CollectionUtils.isEmpty(recognize.getRetrieval())) {
                return Lists.newArrayList();
            }
            List<Retrieval> sorted = recognize.getRetrieval().stream()
                    .sorted(Comparator.comparingDouble(r -> Math.abs(r.getDistance())))
                    .collect(Collectors.toList());
            // Write the ordering back so other holders of this result see the same order.
            recognize.setRetrieval(sorted);
            return sorted;
        } catch (Exception e) {
            log.warn("get embedding result error ", e);
            return Lists.newArrayList();
        }
    }

    /**
     * Parses a retrieval id into a plugin id.
     *
     * @param id raw retrieval id
     * @return the parsed id, or {@code null} if {@code id} is null or not a valid long
     */
    private static Long parsePluginId(String id) {
        if (id == null) {
            return null;
        }
        try {
            return Long.valueOf(id);
        } catch (NumberFormatException e) {
            log.warn("skip embedding retrieval with invalid plugin id: {}", id);
            return null;
        }
    }
}
