/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.chat.server.parser;

import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.parser.ChatParser;
import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/**
 * {@link ChatParser} that turns a natural-language question into SQL by delegating to the
 * headless {@link ChatQueryService#performParsing}.
 *
 * <p>Before parsing it may rewrite the question using the previous non-failed turn of the same
 * chat (multi-turn), and it adds exemplars recalled from the agent's memory collection to the
 * request. After parsing it renders a Markdown summary into {@code textInfo} of every selected
 * parse that is not a plugin query.
 */
public class NL2SQLParser implements ChatParser {
    private static final Logger log = LoggerFactory.getLogger(NL2SQLParser.class);
    private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

    /** How many exemplars are recalled from the agent's memory collection per query. */
    private static final int EXEMPLAR_RECALL_NUM = 5;

    /** How many previous (non-failed) turns are consulted when rewriting a question. */
    private static final int MULTI_TURN_HISTORY_NUM = 1;

    private static final String REWRITE_INSTRUCTION = "#Role: You are a data product manager experienced in data requirements.\n#Task: Your will be provided with current and history questions asked by a user,along with their mapped schema elements(metric, dimension and value),please try understanding the semantics and rewrite a question.\n#Rules: 1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. 2.ONLY respond with the rewritten question.\n#Current Question: %s\n#Current Mapped Schema: %s\n#History Question: %s\n#History Mapped Schema: %s\n#History SQL: %s\n#Rewritten Question: ";

    /**
     * Runs NL2SQL parsing and appends its successful parses to {@code parseResp}.
     *
     * <p>Does nothing when NL2SQL is disabled for the context, or when {@code parseResp}
     * already holds a parse whose score reaches the query length (see {@link #checkSkip}).
     *
     * @param chatParseContext the chat request; its query text may be replaced by a multi-turn
     *     rewrite
     * @param parseResp the response being built; selected parses, SQL time cost and text info
     *     are updated in place
     */
    @Override
    public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
        if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
            return;
        }
        processMultiTurn(chatParseContext);

        QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
        addExemplars(chatParseContext.getAgent().getId(), queryReq);

        ChatQueryService chatQueryService =
                (ChatQueryService) ContextUtils.getBean(ChatQueryService.class);
        ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
        if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
            parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
        }
        // The SQL time cost is copied even when parsing failed.
        parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
        formatParseResult(parseResp);
    }

    /**
     * Returns true when {@code parseResp} already contains a parse whose score is at least the
     * length of the query text.
     *
     * <p>NOTE(review): presumably such a score means the earlier match covered the whole
     * question, so LLM parsing is unnecessary. Confirm against the scoring code.
     */
    private boolean checkSkip(ParseResp parseResp) {
        List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
        for (SemanticParseInfo semanticParseInfo : selectedParses) {
            if (semanticParseInfo.getScore() >= parseResp.getQueryText().length()) {
                return true;
            }
        }
        return false;
    }

    /** Formats every selected parse of {@code parseResp} in place. */
    private void formatParseResult(ParseResp parseResp) {
        List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
        for (SemanticParseInfo parseInfo : selectedParses) {
            formatParseInfo(parseInfo);
        }
    }

    /** Formats a single parse, leaving plugin-query parses untouched. */
    private void formatParseInfo(SemanticParseInfo parseInfo) {
        if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
            formatNL2SQLParseInfo(parseInfo);
        }
    }

    /**
     * Builds the Markdown summary stored in {@code parseInfo.textInfo}.
     *
     * <p>The summary lists the dataset, the first metric, the dimensions, the date range and all
     * dimension/metric filters. The labels are Chinese: dataset, metric, dimension, filter
     * conditions, data time. The parse's own filter sets are only read, never modified.
     */
    private void formatNL2SQLParseInfo(SemanticParseInfo parseInfo) {
        StringBuilder textBuilder = new StringBuilder();
        // "Dataset:"
        textBuilder.append("**\u6570\u636e\u96c6:** ").append(parseInfo.getDataSet().getName()).append(" ");
        // "Metric:" - only the first metric is shown.
        parseInfo.getMetrics().stream().findFirst().ifPresent(
                schemaElement -> textBuilder.append("**\u6307\u6807:** ").append(schemaElement.getName()).append(" "));
        List<String> dimensionNames = parseInfo.getDimensions().stream()
                .map(SchemaElement::getName)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        if (!CollectionUtils.isEmpty(dimensionNames)) {
            // "Dimension:"
            textBuilder.append("**\u7ef4\u5ea6:** ").append(String.join(",", dimensionNames));
        }
        // "Filter conditions:"
        textBuilder.append("\n\n**\u7b5b\u9009\u6761\u4ef6:** \n");
        if (parseInfo.getDateInfo() != null) {
            // "Data time:"
            textBuilder.append("**\u6570\u636e\u65f6\u95f4:** ")
                    .append(parseInfo.getDateInfo().getStartDate())
                    .append("~")
                    .append(parseInfo.getDateInfo().getEndDate())
                    .append(" ");
        }
        // Collect into a fresh set. Using getDimensionFilters() as the accumulator would
        // permanently add the metric filters to the parse's dimension filters.
        Set<QueryFilter> queryFilters = new LinkedHashSet<>();
        if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters())) {
            queryFilters.addAll(parseInfo.getDimensionFilters());
        }
        if (!CollectionUtils.isEmpty(parseInfo.getMetricFilters())) {
            queryFilters.addAll(parseInfo.getMetricFilters());
        }
        for (QueryFilter queryFilter : queryFilters) {
            textBuilder.append("**").append(queryFilter.getName()).append("**")
                    .append(" ").append(queryFilter.getOperator().getValue())
                    .append(" ").append(queryFilter.getValue()).append(" ");
        }
        parseInfo.setTextInfo(textBuilder.toString());
    }

    /**
     * Rewrites the question with the LLM, using the last non-failed turn of the same chat, and
     * stores the rewrite via {@code chatParseContext.setQueryText}.
     *
     * <p>Nothing happens when multi-turn is disabled, when there is no usable history, or when
     * the last turn has no selected parse.
     */
    private void processMultiTurn(ChatParseContext chatParseContext) {
        if (!isMultiTurnEnabled(chatParseContext)) {
            return;
        }
        ChatQueryService chatQueryService =
                (ChatQueryService) ContextUtils.getBean(ChatQueryService.class);
        QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
        MapResp currentMapResult = chatQueryService.performMapping(queryReq);

        List<ParseResp> historyParseResults =
                getHistoryParseResult(chatParseContext.getChatId(), MULTI_TURN_HISTORY_NUM);
        if (historyParseResults.isEmpty()) {
            return;
        }
        ParseResp lastParseResult = historyParseResults.get(0);
        List<SemanticParseInfo> lastSelectedParses = lastParseResult.getSelectedParses();
        if (CollectionUtils.isEmpty(lastSelectedParses)) {
            // There is no previous parse to anchor the rewrite to.
            return;
        }
        SemanticParseInfo lastParseInfo = lastSelectedParses.get(0);
        Long dataSetId = lastParseInfo.getDataSetId();

        // Restrict the current mapping to the data set used in the previous turn.
        String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataSetId));
        String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
        String histSQL = lastParseInfo.getSqlInfo().getCorrectS2SQL();

        RewriteContext rewriteContext = RewriteContext.builder()
                .curtQuestion(currentMapResult.getQueryText())
                .histQuestion(lastParseResult.getQueryText())
                .curtSchema(curtMapStr)
                .histSchema(histMapStr)
                .histSQL(histSQL)
                .llmConfig(queryReq.getLlmConfig())
                .build();
        String rewrittenQuery = rewriteQuery(rewriteContext);
        chatParseContext.setQueryText(rewrittenQuery);
        log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
                lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
    }

    /**
     * Returns whether multi-turn rewriting is enabled.
     *
     * <p>The agent-level config wins when present. Otherwise the global
     * {@link ParserConfig#PARSER_MULTI_TURN_ENABLE} parameter decides.
     */
    private boolean isMultiTurnEnabled(ChatParseContext chatParseContext) {
        MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
        if (agentMultiTurnConfig != null) {
            return Boolean.TRUE.equals(agentMultiTurnConfig.isEnableMultiTurn());
        }
        ParserConfig parserConfig = (ParserConfig) ContextUtils.getBean(ParserConfig.class);
        return Boolean.valueOf(parserConfig.getParameterValue(ParserConfig.PARSER_MULTI_TURN_ENABLE));
    }

    /**
     * Asks the LLM configured in {@code context} to rewrite the current question.
     *
     * @return the model's reply text, used verbatim as the rewritten question
     */
    private String rewriteQuery(RewriteContext context) {
        String promptStr = String.format(REWRITE_INSTRUCTION, context.getCurtQuestion(),
                context.getCurtSchema(), context.getHistQuestion(), context.getHistSchema(),
                context.getHistSQL());
        // Build the prompt straight from the text. Passing it through PromptTemplate would treat
        // any "{{...}}" in the user's question or the history SQL as a template variable and fail.
        Prompt prompt = Prompt.from(promptStr);
        keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
        ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig());
        Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
        String result = response.content().text();
        keyPipelineLog.info("NL2SQLParser modelResp:{}", result);
        return result;
    }

    /**
     * Renders the matched words, grouped by element type, into the prompt fragment
     * {@code 'metrics:':[..],'dimensions:':[..],'values:':[..]}.
     *
     * @param elementMatches the mapped schema elements; {@code null} is treated as no matches
     */
    private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
        List<String> metrics = new ArrayList<>();
        List<String> dimensions = new ArrayList<>();
        List<String> values = new ArrayList<>();
        if (elementMatches != null) {
            for (SchemaElementMatch match : elementMatches) {
                SchemaElementType type = match.getElement().getType();
                if (SchemaElementType.METRIC.equals(type)) {
                    metrics.add(match.getWord());
                } else if (SchemaElementType.DIMENSION.equals(type)) {
                    dimensions.add(match.getWord());
                } else if (SchemaElementType.VALUE.equals(type)) {
                    values.add(match.getWord());
                }
                // Other element types are not included in the prompt.
            }
        }
        return String.format("'metrics:':[%s]", String.join(",", metrics))
                + ","
                + String.format("'dimensions:':[%s]", String.join(",", dimensions))
                + ","
                + String.format("'values:':[%s]", String.join(",", values));
    }

    /**
     * Loads up to {@code multiNum} non-failed parse results for the chat.
     *
     * <p>The repository list is filtered to non-failed results, truncated to its first
     * {@code multiNum} entries, and then reversed. NOTE(review): this presumably assumes the
     * repository returns newest-first, so the result is oldest-first. Confirm against
     * {@code ChatQueryRepository}.
     *
     * @return a new mutable list, possibly empty
     */
    private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
        ChatQueryRepository chatQueryRepository =
                (ChatQueryRepository) ContextUtils.getBean(ChatQueryRepository.class);
        List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
                .stream()
                .filter(p -> p.getState() != ParseResp.ParseState.FAILED)
                .collect(Collectors.toList());
        int limit = Math.max(0, Math.min(multiNum, contextualParseInfoList.size()));
        List<ParseResp> contextualList = new ArrayList<>(contextualParseInfoList.subList(0, limit));
        Collections.reverse(contextualList);
        return contextualList;
    }

    /** Adds exemplars recalled from the agent's memory collection to {@code queryReq}. */
    private void addExemplars(Integer agentId, QueryReq queryReq) {
        ExemplarServiceImpl exemplarManager =
                (ExemplarServiceImpl) ContextUtils.getBean(ExemplarServiceImpl.class);
        EmbeddingConfig embeddingConfig = (EmbeddingConfig) ContextUtils.getBean(EmbeddingConfig.class);
        String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
        queryReq.getExemplars().addAll(exemplarManager.recallExemplars(
                memoryCollectionName, queryReq.getQueryText(), EXEMPLAR_RECALL_NUM));
    }

    /** Inputs used to fill {@code REWRITE_INSTRUCTION} and to choose the LLM that answers it. */
    public static class RewriteContext {
        private String curtQuestion;
        private String histQuestion;
        private String curtSchema;
        private String histSchema;
        private String histSQL;
        private LLMConfig llmConfig;

        RewriteContext(String curtQuestion, String histQuestion, String curtSchema,
                String histSchema, String histSQL, LLMConfig llmConfig) {
            this.curtQuestion = curtQuestion;
            this.histQuestion = histQuestion;
            this.curtSchema = curtSchema;
            this.histSchema = histSchema;
            this.histSQL = histSQL;
            this.llmConfig = llmConfig;
        }

        public static RewriteContextBuilder builder() {
            return new RewriteContextBuilder();
        }

        public String getCurtQuestion() {
            return this.curtQuestion;
        }

        public String getHistQuestion() {
            return this.histQuestion;
        }

        public String getCurtSchema() {
            return this.curtSchema;
        }

        public String getHistSchema() {
            return this.histSchema;
        }

        public String getHistSQL() {
            return this.histSQL;
        }

        public LLMConfig getLlmConfig() {
            return this.llmConfig;
        }

        public void setCurtQuestion(String curtQuestion) {
            this.curtQuestion = curtQuestion;
        }

        public void setHistQuestion(String histQuestion) {
            this.histQuestion = histQuestion;
        }

        public void setCurtSchema(String curtSchema) {
            this.curtSchema = curtSchema;
        }

        public void setHistSchema(String histSchema) {
            this.histSchema = histSchema;
        }

        public void setHistSQL(String histSQL) {
            this.histSQL = histSQL;
        }

        public void setLlmConfig(LLMConfig llmConfig) {
            this.llmConfig = llmConfig;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof RewriteContext)) {
                return false;
            }
            RewriteContext other = (RewriteContext) o;
            return other.canEqual(this)
                    && Objects.equals(getCurtQuestion(), other.getCurtQuestion())
                    && Objects.equals(getHistQuestion(), other.getHistQuestion())
                    && Objects.equals(getCurtSchema(), other.getCurtSchema())
                    && Objects.equals(getHistSchema(), other.getHistSchema())
                    && Objects.equals(getHistSQL(), other.getHistSQL())
                    && Objects.equals(getLlmConfig(), other.getLlmConfig());
        }

        /** Symmetry hook for subclasses, as in Lombok-generated equals. */
        protected boolean canEqual(Object other) {
            return other instanceof RewriteContext;
        }

        @Override
        public int hashCode() {
            return Objects.hash(getCurtQuestion(), getHistQuestion(), getCurtSchema(),
                    getHistSchema(), getHistSQL(), getLlmConfig());
        }

        @Override
        public String toString() {
            return "NL2SQLParser.RewriteContext(curtQuestion=" + this.getCurtQuestion() + ", histQuestion=" + this.getHistQuestion() + ", curtSchema=" + this.getCurtSchema() + ", histSchema=" + this.getHistSchema() + ", histSQL=" + this.getHistSQL() + ", llmConfig=" + this.getLlmConfig() + ")";
        }

        /** Fluent builder for {@link RewriteContext}. */
        public static class RewriteContextBuilder {
            private String curtQuestion;
            private String histQuestion;
            private String curtSchema;
            private String histSchema;
            private String histSQL;
            private LLMConfig llmConfig;

            RewriteContextBuilder() {
            }

            public RewriteContextBuilder curtQuestion(String curtQuestion) {
                this.curtQuestion = curtQuestion;
                return this;
            }

            public RewriteContextBuilder histQuestion(String histQuestion) {
                this.histQuestion = histQuestion;
                return this;
            }

            public RewriteContextBuilder curtSchema(String curtSchema) {
                this.curtSchema = curtSchema;
                return this;
            }

            public RewriteContextBuilder histSchema(String histSchema) {
                this.histSchema = histSchema;
                return this;
            }

            public RewriteContextBuilder histSQL(String histSQL) {
                this.histSQL = histSQL;
                return this;
            }

            public RewriteContextBuilder llmConfig(LLMConfig llmConfig) {
                this.llmConfig = llmConfig;
                return this;
            }

            public RewriteContext build() {
                return new RewriteContext(this.curtQuestion, this.histQuestion, this.curtSchema,
                        this.histSchema, this.histSQL, this.llmConfig);
            }

            @Override
            public String toString() {
                return "NL2SQLParser.RewriteContext.RewriteContextBuilder(curtQuestion=" + this.curtQuestion + ", histQuestion=" + this.histQuestion + ", curtSchema=" + this.curtSchema + ", histSchema=" + this.histSchema + ", histSQL=" + this.histSQL + ", llmConfig=" + this.llmConfig + ")";
            }
        }
    }
}

