package com.tencent.supersonic.chat.server.plugin;

import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.PluginTool;
import com.tencent.supersonic.chat.server.plugin.build.ParamOption;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;

@Component
public class PluginManager {
    private static final Logger log = LoggerFactory.getLogger(PluginManager.class);

    /**
     * Divisor that recovers a plugin id from an embedding id. Embedding ids are the plugin id
     * followed by the example-question index zero-padded to three digits (see
     * {@link #generateUniqueEmbeddingId(int, Long)}), so integer division by 1000 strips the
     * index. Indices of 1000 or more do not fit the three-digit slot and would decode wrongly.
     */
    private static final long EMBEDDING_ID_INDEX_BASE = 1000L;

    @Autowired
    private EmbeddingConfig embeddingConfig;

    @Autowired
    private EmbeddingService embeddingService;

    /**
     * Returns the plugins usable by the agent held in {@code chatParseContext}.
     *
     * <p>Without an agent, every plugin returned by {@link PluginService#getPluginList()} is
     * returned as-is. Otherwise only plugins whose id is listed by one of the agent's
     * {@link AgentToolType#PLUGIN} tools are kept.
     *
     * @param chatParseContext parse context supplying the agent
     * @return the supported plugins; an empty list when the agent's tools reference no plugin
     */
    public static List<ChatPlugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
        PluginService pluginService = (PluginService) ContextUtils.getBean(PluginService.class);
        Agent agent = chatParseContext.getAgent();
        List<ChatPlugin> pluginList = pluginService.getPluginList();
        if (Objects.isNull(agent)) {
            return pluginList;
        }
        // Collect into a Set so the per-plugin membership test below is O(1).
        Set<Long> pluginIds = getPluginTools(agent).stream()
                .map(PluginTool::getPlugins)
                .filter(Objects::nonNull)
                .flatMap(List::stream)
                .collect(Collectors.toSet());
        if (pluginIds.isEmpty()) {
            return Lists.newArrayList();
        }
        List<ChatPlugin> supportedPlugins = pluginList.stream()
                .filter(plugin -> pluginIds.contains(plugin.getId()))
                .collect(Collectors.toList());
        log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(),
                supportedPlugins.stream().map(ChatPlugin::getName).collect(Collectors.toList()));
        return supportedPlugins;
    }

    /**
     * Parses the agent's {@link AgentToolType#PLUGIN} tool definitions (JSON strings) into
     * {@link PluginTool}s. Entries that parse to {@code null} are dropped.
     *
     * @param agent the agent, may be {@code null}
     * @return the parsed tools; never {@code null}
     */
    private static List<PluginTool> getPluginTools(Agent agent) {
        if (agent == null) {
            return Lists.newArrayList();
        }
        List<String> tools = agent.getTools(AgentToolType.PLUGIN);
        if (CollectionUtils.isEmpty(tools)) {
            return Lists.newArrayList();
        }
        return tools.stream()
                .map(tool -> JSONObject.parseObject(tool, PluginTool.class))
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /** Indexes the example questions of a newly added plugin into the embedding store. */
    @EventListener
    public void addPlugin(PluginAddEvent pluginAddEvent) {
        ChatPlugin plugin = pluginAddEvent.getPlugin();
        if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
            requestEmbeddingPluginAdd(convert(Collections.singletonList(plugin)));
        }
    }

    /**
     * Re-indexes a plugin's example questions: the old plugin's question embeddings are deleted
     * and the new plugin's questions are added.
     */
    @EventListener
    public void updatePlugin(PluginUpdateEvent pluginUpdateEvent) {
        ChatPlugin oldPlugin = pluginUpdateEvent.getOldPlugin();
        ChatPlugin newPlugin = pluginUpdateEvent.getNewPlugin();
        if (CollectionUtils.isNotEmpty(oldPlugin.getExampleQuestionList())) {
            requestEmbeddingPluginDelete(getEmbeddingId(Collections.singletonList(oldPlugin)));
        }
        if (CollectionUtils.isNotEmpty(newPlugin.getExampleQuestionList())) {
            requestEmbeddingPluginAdd(convert(Collections.singletonList(newPlugin)));
        }
    }

    /** Removes the example-question embeddings of a deleted plugin. */
    @EventListener
    public void delPlugin(PluginDelEvent pluginDelEvent) {
        ChatPlugin plugin = pluginDelEvent.getPlugin();
        if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
            requestEmbeddingPluginDelete(getEmbeddingId(Collections.singletonList(plugin)));
        }
    }

    /**
     * Deletes the given embedding ids from the preset collection configured in
     * {@link EmbeddingConfig}. Each id is wrapped in an empty-text {@link TextSegment} that only
     * carries the query id. Does nothing for an empty or {@code null} set.
     *
     * @param set embedding ids to delete
     */
    public void requestEmbeddingPluginDelete(Set<String> set) {
        if (CollectionUtils.isEmpty(set)) {
            return;
        }
        List<TextSegment> queries = new ArrayList<>(set.size());
        for (String embeddingId : set) {
            TextSegment query = TextSegment.from("");
            TextSegmentConvert.addQueryId(query, embeddingId);
            queries.add(query);
        }
        this.embeddingService.deleteQuery(this.embeddingConfig.getPresetCollection(), queries);
    }

    /**
     * Adds the given segments to the preset collection configured in {@link EmbeddingConfig}.
     * Does nothing for an empty or {@code null} list.
     *
     * @param list segments to add
     */
    public void requestEmbeddingPluginAdd(List<TextSegment> list) {
        if (CollectionUtils.isEmpty(list)) {
            return;
        }
        this.embeddingService.addQuery(this.embeddingConfig.getPresetCollection(), list);
    }

    /**
     * Retrieves the stored example questions most similar to {@code str}. Every retrieval id is
     * rewritten from an embedding id to the owning plugin's id.
     *
     * @param str the query text
     * @return the result for the (single) query text
     * @throws RuntimeException if the embedding service returns no result
     */
    public RetrieveQueryResult recognize(String str) {
        RetrieveQuery retrieveQuery = RetrieveQuery.builder()
                .queryTextsList(Collections.singletonList(str))
                .build();
        List<RetrieveQueryResult> results = this.embeddingService.retrieveQuery(
                this.embeddingConfig.getPresetCollection(), retrieveQuery,
                this.embeddingConfig.getNResult());
        if (CollectionUtils.isEmpty(results)) {
            throw new RuntimeException("get embedding result failed");
        }
        for (RetrieveQueryResult result : results) {
            if (CollectionUtils.isEmpty(result.getRetrieval())) {
                continue;
            }
            for (Retrieval retrieval : result.getRetrieval()) {
                retrieval.setId(getPluginIdFromEmbeddingId(retrieval.getId()));
            }
        }
        return results.get(0);
    }

    /**
     * Turns each plugin's example questions into {@link TextSegment}s tagged with a unique
     * embedding id derived from the plugin id and the question's position.
     *
     * @param list plugins to convert
     * @return one segment per example question, in plugin then question order
     */
    public List<TextSegment> convert(List<ChatPlugin> list) {
        List<TextSegment> segments = Lists.newArrayList();
        for (ChatPlugin chatPlugin : list) {
            List<String> questions = chatPlugin.getExampleQuestionList();
            if (CollectionUtils.isEmpty(questions)) {
                continue;
            }
            for (int index = 0; index < questions.size(); index++) {
                TextSegment segment = TextSegment.from(questions.get(index));
                TextSegmentConvert.addQueryId(segment,
                        generateUniqueEmbeddingId(index, chatPlugin.getId()));
                segments.add(segment);
            }
        }
        return segments;
    }

    /**
     * Computes the embedding ids under which the plugins' example questions were indexed.
     * Previously the ids were never added to the returned set, so deletions were silently
     * skipped.
     */
    private Set<String> getEmbeddingId(List<ChatPlugin> list) {
        Set<String> embeddingIds = new HashSet<>();
        for (TextSegment segment : convert(list)) {
            embeddingIds.add(TextSegmentConvert.getQueryId(segment));
        }
        return embeddingIds;
    }

    /**
     * Builds an embedding id: the plugin id followed by {@code index} zero-padded to three
     * digits. This is identical to the previous format for indices 0-99. It also keeps indices
     * 100-999 decodable by {@link #getPluginIdFromEmbeddingId(String)}; the previous format
     * emitted four suffix digits there.
     */
    private String generateUniqueEmbeddingId(int index, Long pluginId) {
        return String.format("%s%03d", pluginId, index);
    }

    /**
     * Extracts the plugin id from an embedding id built by
     * {@link #generateUniqueEmbeddingId(int, Long)}. Parses as {@code long} because plugin ids
     * are {@code Long} and {@code id * 1000} overflows {@code int} quickly.
     *
     * @throws NumberFormatException if {@code str} is not a decimal number
     */
    private String getPluginIdFromEmbeddingId(String str) {
        return String.valueOf(Long.parseLong(str) / EMBEDDING_ID_INDEX_BASE);
    }

    /**
     * Decides whether {@code chatPlugin} applies to the current parse and on which data sets.
     *
     * <p>The plugin is first matched against the data sets found in the query (or its default
     * mode when it covers all data sets); no match means it does not apply. Without
     * {@link ParamOption.ParamType#SEMANTIC} parameters, those matched data sets are returned.
     * Otherwise each data set referenced by the semantic parameters is accepted only if every
     * one of its parameters' element ids is among the VALUE/ID elements matched for it.
     *
     * @return {@code (true, dataSetIds)} when the plugin applies, otherwise
     *     {@code (false, emptySet)}
     */
    public static Pair<Boolean, Set<Long>> resolve(ChatPlugin chatPlugin,
            ChatParseContext chatParseContext) {
        SchemaMapInfo mapInfo = chatParseContext.getMapInfo();
        Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(chatPlugin, chatParseContext);
        if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !chatPlugin.isContainsAllDataSet()) {
            return Pair.of(false, Sets.newHashSet());
        }
        List<ParamOption> semanticOptions = getSemanticOption(chatPlugin);
        if (CollectionUtils.isEmpty(semanticOptions)) {
            return Pair.of(true, pluginMatchedDataSet);
        }
        // NOTE(review): candidates come from the parameter options, not pluginMatchedDataSet;
        // kept as-is to preserve existing behaviour.
        Map<Long, List<ParamOption>> optionsByDataSet = semanticOptions.stream()
                .collect(Collectors.groupingBy(ParamOption::getDataSetId));
        Set<Long> resolvedDataSets = new HashSet<>();
        for (Map.Entry<Long, List<ParamOption>> entry : optionsByDataSet.entrySet()) {
            // Depends only on the data set, so compute it once rather than per option.
            Set<Long> matchedElementIds = getSchemaElementMatch(entry.getKey(), mapInfo);
            boolean allMatched = entry.getValue().stream()
                    .allMatch(option -> matchedElementIds.contains(option.getElementId()));
            if (allMatched) {
                resolvedDataSets.add(entry.getKey());
            }
        }
        if (resolvedDataSets.isEmpty()) {
            return Pair.of(false, Sets.newHashSet());
        }
        return Pair.of(true, resolvedDataSets);
    }

    /** Returns the ids of VALUE or ID schema elements matched for {@code dataSetId}. */
    private static Set<Long> getSchemaElementMatch(Long dataSetId, SchemaMapInfo schemaMapInfo) {
        return nullToEmpty(schemaMapInfo.getMatchedElements(dataSetId)).stream()
                .filter(match -> SchemaElementType.VALUE.equals(match.getElement().getType())
                        || SchemaElementType.ID.equals(match.getElement().getType()))
                .map(match -> match.getElement().getId())
                .collect(Collectors.toSet());
    }

    /** Returns {@code list}, or an empty list when it is {@code null}. */
    private static <T> List<T> nullToEmpty(List<T> list) {
        return list == null ? Collections.<T>emptyList() : list;
    }

    /**
     * Returns the plugin's SEMANTIC parameter options parsed from its JSON config; never
     * {@code null} (empty when the config or its options are absent).
     */
    private static List<ParamOption> getSemanticOption(ChatPlugin chatPlugin) {
        WebBase webBase = JSONObject.parseObject(chatPlugin.getConfig(), WebBase.class);
        if (Objects.isNull(webBase)) {
            return Lists.newArrayList();
        }
        List<ParamOption> paramOptions = webBase.getParamOptions();
        if (CollectionUtils.isEmpty(paramOptions)) {
            return Lists.newArrayList();
        }
        return paramOptions.stream()
                .filter(option -> ParamOption.ParamType.SEMANTIC.equals(option.getParamType()))
                .collect(Collectors.toList());
    }

    /**
     * Returns the plugin's data sets that were matched in the query. A plugin covering all data
     * sets yields a singleton containing its default mode instead.
     */
    private static Set<Long> getPluginMatchedDataSet(ChatPlugin chatPlugin,
            ChatParseContext chatParseContext) {
        if (chatPlugin.isContainsAllDataSet()) {
            return Sets.newHashSet(chatPlugin.getDefaultMode());
        }
        Set<Long> pluginMatchedDataSet = Sets.newHashSet();
        List<Long> dataSetList = chatPlugin.getDataSetList();
        if (CollectionUtils.isEmpty(dataSetList)) {
            return pluginMatchedDataSet;
        }
        Set<Long> matchedDataSetInfos = chatParseContext.getMapInfo().getMatchedDataSetInfos();
        for (Long dataSetId : dataSetList) {
            if (matchedDataSetInfos.contains(dataSetId)) {
                pluginMatchedDataSet.add(dataSetId);
            }
        }
        return pluginMatchedDataSet;
    }
}
