/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;

import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.ParseMode;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/**
 * {@link PluginRecognizer} that recalls a plugin by embedding similarity between the user's query
 * text and the entries searched by {@link PluginManager#recognize}.
 *
 * <p>Recall hits are examined in ascending order of absolute distance. The first hit whose id maps
 * to a plugin available in the current parse context, and which {@link PluginManager#resolve}
 * accepts with a non-empty set of data set ids, is returned.
 */
public class EmbeddingRecallRecognizer extends PluginRecognizer {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingRecallRecognizer.class);

    /**
     * Checks whether any plugin is available for this parse context.
     *
     * @param chatParseContext the current parse context
     * @return {@code true} if {@code getPluginList} returns a non-empty list
     */
    @Override
    public boolean checkPreCondition(ChatParseContext chatParseContext) {
        List<ChatPlugin> plugins = this.getPluginList(chatParseContext);
        return !CollectionUtils.isEmpty(plugins);
    }

    /**
     * Recalls the best-matching plugin for the query text of {@code chatParseContext}.
     *
     * <p>The returned result's score is {@code queryText.length() * (1 - distance)}. As a side
     * effect, the chosen plugin's parse mode is set to {@link ParseMode#EMBEDDING_RECALL}.
     *
     * @param chatParseContext the current parse context
     * @return the recall result for the first acceptable hit, or {@code null} if there are no
     *     embedding hits, no available plugins, or no hit resolves to any data set
     */
    @Override
    public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
        String text = chatParseContext.getQueryText();
        List<Retrieval> embeddingRetrievals = this.embeddingRecall(text);
        if (CollectionUtils.isEmpty(embeddingRetrievals)) {
            return null;
        }
        List<ChatPlugin> plugins = this.getPluginList(chatParseContext);
        if (CollectionUtils.isEmpty(plugins)) {
            return null;
        }
        // Keep the first plugin on a duplicate id rather than letting toMap throw.
        Map<Long, ChatPlugin> pluginMap = plugins.stream()
                .collect(Collectors.toMap(ChatPlugin::getId, p -> p, (first, second) -> first));
        for (Retrieval embeddingRetrieval : embeddingRetrievals) {
            Long pluginId = parsePluginId(embeddingRetrieval.getId());
            if (pluginId == null) {
                continue;
            }
            ChatPlugin plugin = pluginMap.get(pluginId);
            if (plugin == null) {
                // Hit refers to a plugin that is not available in this context.
                continue;
            }
            Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
            log.info("embedding plugin resolve: {}", pair);
            // Null-safe check: a null left value is treated as "not resolved" instead of NPE.
            if (pair == null || !Boolean.TRUE.equals(pair.getLeft())) {
                continue;
            }
            Set<Long> dataSetIds = pair.getRight();
            if (CollectionUtils.isEmpty(dataSetIds)) {
                continue;
            }
            plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
            double distance = embeddingRetrieval.getDistance();
            double score = (double) text.length() * (1.0 - distance);
            return PluginRecallResult.builder()
                    .plugin(plugin)
                    .dataSetIds(dataSetIds)
                    .score(score)
                    .distance(distance)
                    .build();
        }
        return null;
    }

    /**
     * Queries the {@link PluginManager} bean for embedding hits on {@code embeddingText}.
     *
     * <p>Hits are sorted by ascending absolute distance, and the sorted list is also written back
     * into the {@link RetrieveQueryResult} returned by {@code recognize}.
     *
     * @param embeddingText the text to recognize
     * @return the sorted hits; an empty list if there are none or if the lookup fails (failures are
     *     logged, never propagated)
     */
    public List<Retrieval> embeddingRecall(String embeddingText) {
        try {
            PluginManager pluginManager = (PluginManager) ContextUtils.getBean(PluginManager.class);
            RetrieveQueryResult embeddingResp = pluginManager.recognize(embeddingText);
            if (embeddingResp == null || CollectionUtils.isEmpty(embeddingResp.getRetrieval())) {
                return Lists.newArrayList();
            }
            List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval().stream()
                    .sorted(Comparator.comparingDouble((Retrieval o) -> Math.abs(o.getDistance())))
                    .collect(Collectors.toList());
            embeddingResp.setRetrieval(embeddingRetrievals);
            return embeddingRetrievals;
        } catch (Exception e) {
            // Best-effort recall: any failure degrades to "no hits" rather than breaking parsing.
            log.warn("get embedding result error ", e);
            return Lists.newArrayList();
        }
    }

    /**
     * Parses an embedding retrieval id as a plugin id.
     *
     * @param id the retrieval id
     * @return the parsed id, or {@code null} if {@code id} is null or not a valid long
     */
    private static Long parsePluginId(String id) {
        if (id == null) {
            return null;
        }
        try {
            return Long.parseLong(id);
        } catch (NumberFormatException e) {
            log.warn("skip embedding retrieval with non-numeric id: {}", id);
            return null;
        }
    }
}

