/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.chat.server.plugin;

import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.PluginTool;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.build.ParamOption;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

/**
 * Manages chat plugins: decides which plugins an agent may use, keeps the
 * plugins' example questions in sync with the embedding store in response to
 * plugin add/update/delete events, and resolves which data sets a plugin
 * matches for a given parse context.
 *
 * <p>Embedding ids are built as {@code <pluginId><3-digit sequence>}; see
 * {@link #generateUniqueEmbeddingId(int, Long)} and
 * {@link #getPluginIdFromEmbeddingId(String)}.
 */
@Component
public class PluginManager {
    private static final Logger log = LoggerFactory.getLogger(PluginManager.class);

    /**
     * Number of sequence slots reserved per plugin inside an embedding id.
     * The encoder pads the sequence to three digits and the decoder divides by
     * this value, so the two must stay in sync.
     */
    private static final int EMBEDDING_ID_SEQUENCE_RANGE = 1000;

    @Autowired
    private EmbeddingConfig embeddingConfig;
    @Autowired
    private EmbeddingService embeddingService;

    /**
     * Returns the plugins that the agent of the given context is allowed to use.
     *
     * @param chatParseContext parse context carrying the (possibly absent) agent
     * @return every known plugin when the context has no agent; otherwise only
     *         the plugins whose id is listed in one of the agent's plugin tools
     *         (an empty list when the agent lists none)
     */
    public static List<ChatPlugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
        PluginService pluginService = ContextUtils.getBean(PluginService.class);
        Agent agent = chatParseContext.getAgent();
        List<ChatPlugin> plugins = pluginService.getPluginList();
        if (agent == null) {
            return plugins;
        }
        // A set gives constant-time membership checks in the filter below.
        // Tools without a plugin list are skipped instead of failing the stream.
        Set<Long> pluginIds = getPluginTools(agent).stream()
                .map(PluginTool::getPlugins)
                .filter(Objects::nonNull)
                .flatMap(Collection::stream)
                .collect(Collectors.toSet());
        if (pluginIds.isEmpty()) {
            return Lists.newArrayList();
        }
        plugins = plugins.stream()
                .filter(plugin -> pluginIds.contains(plugin.getId()))
                .collect(Collectors.toList());
        log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(),
                plugins.stream().map(ChatPlugin::getName).collect(Collectors.toList()));
        return plugins;
    }

    /**
     * Parses the agent's PLUGIN-type tool definitions (JSON strings) into
     * {@link PluginTool} objects.
     *
     * @param agent agent to inspect; may be {@code null}
     * @return the parsed tools, or an empty list when the agent is {@code null}
     *         or has no plugin tools
     */
    private static List<PluginTool> getPluginTools(Agent agent) {
        if (agent == null) {
            return Lists.newArrayList();
        }
        List<String> tools = agent.getTools(AgentToolType.PLUGIN);
        if (org.apache.commons.collections.CollectionUtils.isEmpty(tools)) {
            return Lists.newArrayList();
        }
        return tools.stream()
                .map(tool -> JSONObject.parseObject(tool, PluginTool.class))
                .collect(Collectors.toList());
    }

    /**
     * Adds the example questions of a newly added plugin to the embedding store.
     * Plugins without example questions are ignored.
     */
    @EventListener
    public void addPlugin(PluginAddEvent pluginAddEvent) {
        ChatPlugin plugin = pluginAddEvent.getPlugin();
        if (org.apache.commons.collections.CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
            this.requestEmbeddingPluginAdd(this.convert(Collections.singletonList(plugin)));
        }
    }

    /**
     * Replaces the stored example questions of an updated plugin: the old
     * plugin's entries are deleted first, then the new plugin's are added.
     */
    @EventListener
    public void updatePlugin(PluginUpdateEvent pluginUpdateEvent) {
        ChatPlugin oldPlugin = pluginUpdateEvent.getOldPlugin();
        ChatPlugin newPlugin = pluginUpdateEvent.getNewPlugin();
        if (org.apache.commons.collections.CollectionUtils.isNotEmpty(oldPlugin.getExampleQuestionList())) {
            this.requestEmbeddingPluginDelete(this.getEmbeddingId(Collections.singletonList(oldPlugin)));
        }
        if (org.apache.commons.collections.CollectionUtils.isNotEmpty(newPlugin.getExampleQuestionList())) {
            this.requestEmbeddingPluginAdd(this.convert(Collections.singletonList(newPlugin)));
        }
    }

    /**
     * Removes the example questions of a deleted plugin from the embedding store.
     * Plugins without example questions are ignored.
     */
    @EventListener
    public void delPlugin(PluginDelEvent pluginDelEvent) {
        ChatPlugin plugin = pluginDelEvent.getPlugin();
        if (org.apache.commons.collections.CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
            this.requestEmbeddingPluginDelete(this.getEmbeddingId(Collections.singletonList(plugin)));
        }
    }

    /**
     * Asks the embedding service to delete the entries with the given query ids
     * from the preset collection. Does nothing when {@code queryIds} is null or empty.
     *
     * @param queryIds embedding ids to delete
     */
    public void requestEmbeddingPluginDelete(Set<String> queryIds) {
        if (org.apache.commons.collections.CollectionUtils.isEmpty(queryIds)) {
            return;
        }
        String presetCollection = this.embeddingConfig.getPresetCollection();
        List<TextSegment> queries = new ArrayList<>(queryIds.size());
        for (String id : queryIds) {
            // Only the query id is attached; the segment text is left empty.
            TextSegment query = TextSegment.from("");
            TextSegmentConvert.addQueryId(query, id);
            queries.add(query);
        }
        this.embeddingService.deleteQuery(presetCollection, queries);
    }

    /**
     * Asks the embedding service to add the given segments to the preset
     * collection. Does nothing when {@code queries} is null or empty.
     *
     * @param queries segments to store, each already tagged with its query id
     */
    public void requestEmbeddingPluginAdd(List<TextSegment> queries) {
        if (org.apache.commons.collections.CollectionUtils.isEmpty(queries)) {
            return;
        }
        String presetCollection = this.embeddingConfig.getPresetCollection();
        this.embeddingService.addQuery(presetCollection, queries);
    }

    /**
     * Retrieves the stored example questions closest to {@code embeddingText}
     * and rewrites every retrieval id from an embedding id to its plugin id.
     *
     * @param embeddingText text to look up in the preset collection
     * @return the first result returned by the embedding service
     * @throws RuntimeException when the embedding service returns no result
     */
    public RetrieveQueryResult recognize(String embeddingText) {
        RetrieveQuery retrieveQuery = RetrieveQuery.builder()
                .queryTextsList(Collections.singletonList(embeddingText))
                .build();
        List<RetrieveQueryResult> resultList = this.embeddingService.retrieveQuery(
                this.embeddingConfig.getPresetCollection(), retrieveQuery, this.embeddingConfig.getNResult());
        if (org.apache.commons.collections.CollectionUtils.isEmpty(resultList)) {
            throw new RuntimeException("get embedding result failed");
        }
        for (RetrieveQueryResult embeddingResp : resultList) {
            List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
            for (Retrieval embeddingRetrieval : embeddingRetrievals) {
                embeddingRetrieval.setId(this.getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
            }
        }
        return resultList.get(0);
    }

    /**
     * Turns every example question of the given plugins into a
     * {@link TextSegment} tagged with a per-plugin sequential embedding id.
     *
     * @param plugins plugins whose example questions should be converted
     * @return one segment per example question, in plugin and question order;
     *         plugins without example questions contribute nothing
     */
    public List<TextSegment> convert(List<ChatPlugin> plugins) {
        List<TextSegment> queries = new ArrayList<>();
        for (ChatPlugin plugin : plugins) {
            List<String> exampleQuestions = plugin.getExampleQuestionList();
            if (exampleQuestions == null) {
                continue;
            }
            int num = 0;
            for (String pattern : exampleQuestions) {
                TextSegment query = TextSegment.from(pattern);
                TextSegmentConvert.addQueryId(query, this.generateUniqueEmbeddingId(num, plugin.getId()));
                queries.add(query);
                ++num;
            }
        }
        return queries;
    }

    /**
     * Collects the embedding ids that {@link #convert(List)} assigns to the
     * example questions of the given plugins.
     *
     * @param plugins plugins whose embedding ids are wanted
     * @return the set of embedding ids (empty when there are no example questions)
     */
    private Set<String> getEmbeddingId(List<ChatPlugin> plugins) {
        Set<String> embeddingIdSet = new HashSet<>();
        for (TextSegment query : this.convert(plugins)) {
            // Previously the id was written back onto the segment and the set
            // stayed empty, so nothing was ever deleted from the store.
            embeddingIdSet.add(TextSegmentConvert.getQueryId(query));
        }
        return embeddingIdSet;
    }

    /**
     * Builds the embedding id {@code <pluginId><num padded to 3 digits>}.
     *
     * <p>Output is unchanged for {@code num} in 0..99. For 100..999 the
     * sequence now also occupies exactly three digits, so the id still decodes
     * to the right plugin. Sequences of 1000 and above do not fit this scheme.
     *
     * @param num      zero-based index of the example question within the plugin
     * @param pluginId id of the owning plugin
     * @return the embedding id
     */
    private String generateUniqueEmbeddingId(int num, Long pluginId) {
        return String.format("%s%03d", pluginId, num);
    }

    /**
     * Extracts the plugin id from an embedding id by dropping the three-digit
     * sequence suffix.
     *
     * @param id embedding id produced by {@link #generateUniqueEmbeddingId(int, Long)}
     * @return the plugin id as a decimal string
     * @throws NumberFormatException when {@code id} is not a decimal number
     */
    private String getPluginIdFromEmbeddingId(String id) {
        // Plugin ids are Long, so the encoded id can exceed the int range.
        return String.valueOf(Long.parseLong(id) / EMBEDDING_ID_SEQUENCE_RANGE);
    }

    /**
     * Decides whether the plugin applies to the given parse context and for
     * which data sets.
     *
     * <p>The plugin must first match at least one data set (or be flagged as
     * covering all data sets). If it declares no semantic parameter options,
     * those data sets are returned directly. Otherwise a data set qualifies
     * only when every semantic parameter option declared for it refers to an
     * element that was matched (as VALUE or ID) in the schema map info.
     *
     * @param plugin           plugin to test
     * @param chatParseContext current parse context
     * @return a pair of (matched?, matched data set ids); the set is empty when
     *         the flag is {@code false}
     */
    public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
        SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
        Set<Long> pluginMatchedDataSet = PluginManager.getPluginMatchedDataSet(plugin, chatParseContext);
        if (org.apache.commons.collections.CollectionUtils.isEmpty(pluginMatchedDataSet)
                && !plugin.isContainsAllDataSet()) {
            Set<Long> none = new HashSet<>();
            return Pair.of(Boolean.FALSE, none);
        }
        List<ParamOption> paramOptions = PluginManager.getSemanticOption(plugin);
        if (org.apache.commons.collections.CollectionUtils.isEmpty(paramOptions)) {
            return Pair.of(Boolean.TRUE, pluginMatchedDataSet);
        }
        Set<Long> matchedDataSet = new HashSet<>();
        Map<Long, List<ParamOption>> paramOptionMap =
                paramOptions.stream().collect(Collectors.groupingBy(ParamOption::getDataSetId));
        for (Map.Entry<Long, List<ParamOption>> entry : paramOptionMap.entrySet()) {
            Long dataSetId = entry.getKey();
            // The matched elements depend only on the data set, so look them
            // up once per data set rather than once per parameter option.
            Set<Long> elementIdSet = PluginManager.getSchemaElementMatch(dataSetId, schemaMapInfo);
            boolean allParamsMatched = entry.getValue().stream()
                    .allMatch(paramOption -> elementIdSet.contains(paramOption.getElementId()));
            if (allParamsMatched) {
                matchedDataSet.add(dataSetId);
            }
        }
        if (matchedDataSet.isEmpty()) {
            Set<Long> none = new HashSet<>();
            return Pair.of(Boolean.FALSE, none);
        }
        return Pair.of(Boolean.TRUE, matchedDataSet);
    }

    /**
     * Returns the ids of the schema elements of type VALUE or ID that were
     * matched for the given data set.
     *
     * @param modelId       id passed to {@code SchemaMapInfo.getMatchedElements}
     * @param schemaMapInfo schema mapping result of the current query
     * @return the matched element ids, or an empty set when nothing matched
     */
    private static Set<Long> getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
        List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
        if (CollectionUtils.isEmpty(schemaElementMatches)) {
            return Sets.newHashSet();
        }
        return schemaElementMatches.stream()
                .map(SchemaElementMatch::getElement)
                .filter(element -> SchemaElementType.VALUE.equals(element.getType())
                        || SchemaElementType.ID.equals(element.getType()))
                .map(SchemaElement::getId)
                .collect(Collectors.toSet());
    }

    /**
     * Reads the plugin config as a {@link WebBase} and returns its parameter
     * options of type SEMANTIC.
     *
     * @param plugin plugin whose config is parsed
     * @return the semantic parameter options; an empty list (never {@code null})
     *         when the config does not parse to an object or has no options
     */
    private static List<ParamOption> getSemanticOption(ChatPlugin plugin) {
        WebBase webBase = JSONObject.parseObject(plugin.getConfig(), WebBase.class);
        if (webBase == null) {
            return Lists.newArrayList();
        }
        List<ParamOption> paramOptions = webBase.getParamOptions();
        if (org.apache.commons.collections.CollectionUtils.isEmpty(paramOptions)) {
            return Lists.newArrayList();
        }
        return paramOptions.stream()
                .filter(paramOption -> ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType()))
                .collect(Collectors.toList());
    }

    /**
     * Returns the plugin's data sets that were also matched in the parse context.
     *
     * @param plugin           plugin whose data set list is checked
     * @param chatParseContext context providing the matched data sets
     * @return a set holding only the plugin's default mode when the plugin
     *         covers all data sets; otherwise the intersection of the plugin's
     *         data sets with the context's matched data sets
     */
    private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
        if (plugin.isContainsAllDataSet()) {
            return Sets.newHashSet(plugin.getDefaultMode());
        }
        Set<Long> pluginMatchedDataSet = new HashSet<>();
        List<Long> dataSetList = plugin.getDataSetList();
        if (dataSetList == null) {
            return pluginMatchedDataSet;
        }
        Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
        for (Long dataSetId : dataSetList) {
            if (matchedDataSets.contains(dataSetId)) {
                pluginMatchedDataSet.add(dataSetId);
            }
        }
        return pluginMatchedDataSet;
    }
}

