package com.tencent.supersonic.chat.server.parser;

import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/tencent/supersonic/chat/server/parser/NL2SQLParser.class */
/**
 * Chat parser that asks the headless {@link ChatQueryService} to translate the user's question
 * into SQL.
 *
 * <p>When multi-turn is enabled (per agent, or via {@link ParserConfig} as a fallback), the
 * current question is first rewritten by an LLM using the most recent successful parse of the
 * same chat. The resulting parses are appended to the caller's {@link ParseResp} and each
 * non-plugin parse gets a markdown summary stored via {@code setTextInfo}.
 */
public class NL2SQLParser implements ChatParser {
    private static final Logger log = LoggerFactory.getLogger(NL2SQLParser.class);
    private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

    /** Number of exemplars recalled from the agent's memory collection per query. */
    private static final int EXEMPLAR_RECALL_NUM = 5;

    /** Number of previous (non-failed) parses consulted when rewriting a follow-up question. */
    private static final int HISTORY_PARSE_NUM = 1;

    // Format placeholders, in order: current question, current mapped schema,
    // history question, history mapped schema, history SQL.
    private static final String REWRITE_INSTRUCTION = "#Role: You are a data product manager experienced in data requirements.\n#Task: Your will be provided with current and history questions asked by a user,along with their mapped schema elements(metric, dimension and value),please try understanding the semantics and rewrite a question.\n#Rules: 1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. 2.ONLY respond with the rewritten question.\n#Current Question: %s\n#Current Mapped Schema: %s\n#History Question: %s\n#History Mapped Schema: %s\n#History SQL: %s\n#Rewritten Question: ";

    /**
     * Inputs for the multi-turn question-rewrite prompt: the current and previous questions,
     * summaries of their mapped schema elements, the previous SQL, and the LLM configuration
     * used to perform the rewrite.
     */
    public static class RewriteContext {
        private String curtQuestion;
        private String histQuestion;
        private String curtSchema;
        private String histSchema;
        private String histSQL;
        private LLMConfig llmConfig;

        /** Fluent builder for {@link RewriteContext}. */
        public static class RewriteContextBuilder {
            private String curtQuestion;
            private String histQuestion;
            private String curtSchema;
            private String histSchema;
            private String histSQL;
            private LLMConfig llmConfig;

            RewriteContextBuilder() {
            }

            public RewriteContextBuilder curtQuestion(String curtQuestion) {
                this.curtQuestion = curtQuestion;
                return this;
            }

            public RewriteContextBuilder histQuestion(String histQuestion) {
                this.histQuestion = histQuestion;
                return this;
            }

            public RewriteContextBuilder curtSchema(String curtSchema) {
                this.curtSchema = curtSchema;
                return this;
            }

            public RewriteContextBuilder histSchema(String histSchema) {
                this.histSchema = histSchema;
                return this;
            }

            public RewriteContextBuilder histSQL(String histSQL) {
                this.histSQL = histSQL;
                return this;
            }

            public RewriteContextBuilder llmConfig(LLMConfig llmConfig) {
                this.llmConfig = llmConfig;
                return this;
            }

            public RewriteContext build() {
                return new RewriteContext(curtQuestion, histQuestion, curtSchema, histSchema,
                        histSQL, llmConfig);
            }

            @Override
            public String toString() {
                return "NL2SQLParser.RewriteContext.RewriteContextBuilder(curtQuestion=" + this.curtQuestion + ", histQuestion=" + this.histQuestion + ", curtSchema=" + this.curtSchema + ", histSchema=" + this.histSchema + ", histSQL=" + this.histSQL + ", llmConfig=" + this.llmConfig + ")";
            }
        }

        RewriteContext(String curtQuestion, String histQuestion, String curtSchema,
                String histSchema, String histSQL, LLMConfig llmConfig) {
            this.curtQuestion = curtQuestion;
            this.histQuestion = histQuestion;
            this.curtSchema = curtSchema;
            this.histSchema = histSchema;
            this.histSQL = histSQL;
            this.llmConfig = llmConfig;
        }

        public static RewriteContextBuilder builder() {
            return new RewriteContextBuilder();
        }

        public String getCurtQuestion() {
            return this.curtQuestion;
        }

        public String getHistQuestion() {
            return this.histQuestion;
        }

        public String getCurtSchema() {
            return this.curtSchema;
        }

        public String getHistSchema() {
            return this.histSchema;
        }

        public String getHistSQL() {
            return this.histSQL;
        }

        public LLMConfig getLlmConfig() {
            return this.llmConfig;
        }

        public void setCurtQuestion(String curtQuestion) {
            this.curtQuestion = curtQuestion;
        }

        public void setHistQuestion(String histQuestion) {
            this.histQuestion = histQuestion;
        }

        public void setCurtSchema(String curtSchema) {
            this.curtSchema = curtSchema;
        }

        public void setHistSchema(String histSchema) {
            this.histSchema = histSchema;
        }

        public void setHistSQL(String histSQL) {
            this.histSQL = histSQL;
        }

        public void setLlmConfig(LLMConfig llmConfig) {
            this.llmConfig = llmConfig;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof RewriteContext)) {
                return false;
            }
            RewriteContext other = (RewriteContext) obj;
            return other.canEqual(this)
                    && Objects.equals(getCurtQuestion(), other.getCurtQuestion())
                    && Objects.equals(getHistQuestion(), other.getHistQuestion())
                    && Objects.equals(getCurtSchema(), other.getCurtSchema())
                    && Objects.equals(getHistSchema(), other.getHistSchema())
                    && Objects.equals(getHistSQL(), other.getHistSQL())
                    && Objects.equals(getLlmConfig(), other.getLlmConfig());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof RewriteContext;
        }

        @Override
        public int hashCode() {
            // Same 59-multiplier / 43-for-null scheme as before, so hash values are unchanged.
            int result = 1;
            result = result * 59 + fieldHash(getCurtQuestion());
            result = result * 59 + fieldHash(getHistQuestion());
            result = result * 59 + fieldHash(getCurtSchema());
            result = result * 59 + fieldHash(getHistSchema());
            result = result * 59 + fieldHash(getHistSQL());
            result = result * 59 + fieldHash(getLlmConfig());
            return result;
        }

        /** Hash contribution of one field: 43 for null, otherwise its own hashCode. */
        private static int fieldHash(Object value) {
            return value == null ? 43 : value.hashCode();
        }

        @Override
        public String toString() {
            return "NL2SQLParser.RewriteContext(curtQuestion=" + getCurtQuestion() + ", histQuestion=" + getHistQuestion() + ", curtSchema=" + getCurtSchema() + ", histSchema=" + getHistSchema() + ", histSQL=" + getHistSQL() + ", llmConfig=" + getLlmConfig() + ")";
        }
    }

    /**
     * Runs NL2SQL parsing for the chat context and merges the results into {@code parseResp}.
     *
     * <p>Does nothing when NL2SQL is disabled for the context or when {@link #checkSkip} finds
     * an existing parse that already scores at least the question length. Otherwise the question
     * may be rewritten (multi-turn), exemplars are attached, and non-failed parse results are
     * appended to {@code parseResp}. The SQL timing is always copied over, and every selected
     * parse is then re-formatted.
     *
     * @param chatParseContext the chat request; its query text may be replaced by a rewrite
     * @param parseResp the accumulated response, mutated in place
     */
    @Override // com.tencent.supersonic.chat.server.parser.ChatParser
    public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
        if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
            return;
        }
        processMultiTurn(chatParseContext);

        QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
        addExemplars(chatParseContext.getAgent().getId(), queryReq);

        ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
        ParseResp text2SqlResp = chatQueryService.performParsing(queryReq);
        if (!ParseResp.ParseState.FAILED.equals(text2SqlResp.getState())) {
            parseResp.getSelectedParses().addAll(text2SqlResp.getSelectedParses());
        }
        parseResp.getParseTimeCost().setSqlTime(text2SqlResp.getParseTimeCost().getSqlTime());
        formatParseResult(parseResp);
    }

    /**
     * Returns true when some already-selected parse has a score greater than or equal to the
     * length of the query text. NOTE(review): presumably this means an earlier parser matched
     * the whole question, so NL2SQL is unnecessary; confirm the score semantics.
     */
    private boolean checkSkip(ParseResp parseResp) {
        // Query text is only read when there is at least one parse, matching prior behavior.
        return parseResp.getSelectedParses().stream()
                .anyMatch(parseInfo -> parseInfo.getScore() >= parseResp.getQueryText().length());
    }

    /** Applies {@link #formatParseInfo} to every selected parse in the response. */
    private void formatParseResult(ParseResp parseResp) {
        for (SemanticParseInfo parseInfo : parseResp.getSelectedParses()) {
            formatParseInfo(parseInfo);
        }
    }

    /** Formats a single parse, leaving plugin-mode parses untouched. */
    private void formatParseInfo(SemanticParseInfo parseInfo) {
        if (PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
            return;
        }
        formatNL2SQLParseInfo(parseInfo);
    }

    /**
     * Builds a markdown summary of the parse and stores it via {@code setTextInfo}. The summary
     * covers the data set, the first metric, the dimensions, the date range, and all dimension
     * and metric filters.
     *
     * <p>The parse's own filter sets are only read, never modified.
     */
    private void formatNL2SQLParseInfo(SemanticParseInfo parseInfo) {
        StringBuilder text = new StringBuilder();
        text.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" ");
        parseInfo.getMetrics().stream().findFirst().ifPresent(metric ->
                text.append("**指标:** ").append(metric.getName()).append(" "));

        List<String> dimensionNames = parseInfo.getDimensions().stream()
                .map(dimension -> dimension.getName())
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        if (!CollectionUtils.isEmpty(dimensionNames)) {
            text.append("**维度:** ").append(String.join(",", dimensionNames));
        }

        text.append("\n\n**筛选条件:** \n");
        if (parseInfo.getDateInfo() != null) {
            text.append("**数据时间:** ").append(parseInfo.getDateInfo().getStartDate())
                    .append("~").append(parseInfo.getDateInfo().getEndDate()).append(" ");
        }

        // Merge into a fresh set. Previously metric filters were added into the parse's live
        // dimension-filter set, and the guard was missing a negation (skipping metric-only
        // filters and throwing NPE when both sets were empty).
        Set<QueryFilter> filters = new LinkedHashSet<>();
        if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters())) {
            filters.addAll(parseInfo.getDimensionFilters());
        }
        if (!CollectionUtils.isEmpty(parseInfo.getMetricFilters())) {
            filters.addAll(parseInfo.getMetricFilters());
        }
        for (QueryFilter filter : filters) {
            text.append("**").append(filter.getName()).append("**")
                    .append(" ").append(filter.getOperator().getValue())
                    .append(" ").append(filter.getValue())
                    .append(" ");
        }
        parseInfo.setTextInfo(text.toString());
    }

    /**
     * When multi-turn is enabled, rewrites the context's query text with an LLM.
     *
     * <p>The rewrite uses the current question's schema mapping plus the most recent non-failed
     * parse of the same chat. Returns without changes when there is no usable history.
     */
    private void processMultiTurn(ChatParseContext chatParseContext) {
        if (!isMultiTurnEnabled(chatParseContext)) {
            return;
        }
        ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
        QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
        MapResp mapResp = chatQueryService.performMapping(queryReq);

        List<ParseResp> historyParseResults =
                getHistoryParseResult(chatParseContext.getChatId().intValue(), HISTORY_PARSE_NUM);
        if (historyParseResults.isEmpty()) {
            return;
        }
        ParseResp lastParseResp = historyParseResults.get(0);
        // A non-failed history entry can still carry no selected parse; nothing to rewrite from.
        if (CollectionUtils.isEmpty(lastParseResp.getSelectedParses())) {
            return;
        }
        SemanticParseInfo lastParseInfo = lastParseResp.getSelectedParses().get(0);

        RewriteContext rewriteContext = RewriteContext.builder()
                .curtQuestion(mapResp.getQueryText())
                .histQuestion(lastParseResp.getQueryText())
                .curtSchema(generateSchemaPrompt(
                        mapResp.getMapInfo().getMatchedElements(lastParseInfo.getDataSetId())))
                .histSchema(generateSchemaPrompt(lastParseInfo.getElementMatches()))
                .histSQL(lastParseInfo.getSqlInfo().getCorrectS2SQL())
                .llmConfig(queryReq.getLlmConfig())
                .build();
        String rewrittenQuery = rewriteQuery(rewriteContext);
        chatParseContext.setQueryText(rewrittenQuery);
        log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
                lastParseResp.getQueryText(), mapResp.getQueryText(), rewrittenQuery);
    }

    /**
     * Returns the agent's multi-turn flag when the agent has a {@link MultiTurnConfig}.
     * Otherwise falls back to the {@link ParserConfig#PARSER_MULTI_TURN_ENABLE} parameter.
     */
    private boolean isMultiTurnEnabled(ChatParseContext chatParseContext) {
        MultiTurnConfig multiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
        if (multiTurnConfig != null) {
            return multiTurnConfig.isEnableMultiTurn();
        }
        ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
        return Boolean.parseBoolean(
                parserConfig.getParameterValue(ParserConfig.PARSER_MULTI_TURN_ENABLE));
    }

    /** Fills {@link #REWRITE_INSTRUCTION}, sends it to the configured LLM, and returns its reply. */
    private String rewriteQuery(RewriteContext context) {
        String promptText = String.format(REWRITE_INSTRUCTION, context.getCurtQuestion(),
                context.getCurtSchema(), context.getHistQuestion(), context.getHistSchema(),
                context.getHistSQL());
        // Send the text verbatim. Passing it through PromptTemplate would treat any "{{...}}"
        // in user questions or SQL as template variables and fail on missing values.
        Prompt prompt = Prompt.from(promptText);
        keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptText);
        Response<AiMessage> response = ChatLanguageModelProvider
                .provide(context.getLlmConfig())
                .generate(prompt.toUserMessage());
        String rewritten = response.content().text();
        keyPipelineLog.info("NL2SQLParser modelResp:{}", rewritten);
        return rewritten;
    }

    /**
     * Summarizes matched schema elements as
     * {@code 'metrics:':[..],'dimensions:':[..],'values:':[..]}, listing each match's word.
     * Matches of any other element type are ignored; a null list yields empty groups.
     */
    private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
        List<String> metrics = new ArrayList<>();
        List<String> dimensions = new ArrayList<>();
        List<String> values = new ArrayList<>();
        if (elementMatches != null) {
            for (SchemaElementMatch match : elementMatches) {
                SchemaElementType type = match.getElement().getType();
                if (SchemaElementType.METRIC.equals(type)) {
                    metrics.add(match.getWord());
                } else if (SchemaElementType.DIMENSION.equals(type)) {
                    dimensions.add(match.getWord());
                } else if (SchemaElementType.VALUE.equals(type)) {
                    values.add(match.getWord());
                }
            }
        }
        return String.format("'metrics:':[%s]", String.join(",", metrics)) + ","
                + String.format("'dimensions:':[%s]", String.join(",", dimensions)) + ","
                + String.format("'values:':[%s]", String.join(",", values));
    }

    /**
     * Loads the contextual parse history for a chat, drops failed parses, keeps at most
     * {@code limit} entries from the head of the repository's list, and returns them reversed.
     * NOTE(review): the repository's ordering is not visible here; confirm it is newest-first.
     */
    private List<ParseResp> getHistoryParseResult(int chatId, int limit) {
        ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
        List<ParseResp> nonFailed = chatQueryRepository.getContextualParseInfo(chatId).stream()
                .filter(parseResp -> parseResp.getState() != ParseResp.ParseState.FAILED)
                .collect(Collectors.toList());
        List<ParseResp> recent = nonFailed.subList(0, Math.min(limit, nonFailed.size()));
        Collections.reverse(recent);
        return recent;
    }

    /**
     * Appends up to {@link #EXEMPLAR_RECALL_NUM} exemplars, recalled from the agent's memory
     * collection for the request's query text, to the request's exemplar list.
     */
    private void addExemplars(Integer agentId, QueryReq queryReq) {
        ExemplarServiceImpl exemplarService = ContextUtils.getBean(ExemplarServiceImpl.class);
        EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
        String collectionName = embeddingConfig.getMemoryCollectionName(agentId);
        queryReq.getExemplars().addAll(exemplarService.recallExemplars(
                collectionName, queryReq.getQueryText(), EXEMPLAR_RECALL_NUM));
    }
}
