/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.chat.server.parser;

import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.ModelProvider;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

public class NL2SQLParser
implements ChatQueryParser {
    private static final Logger log = LoggerFactory.getLogger(NL2SQLParser.class);
    private static final Logger keyPipelineLog = LoggerFactory.getLogger((String)"keyPipeline");
    private static final String REWRITE_USER_QUESTION_INSTRUCTION = "#Role: You are a data product manager experienced in data requirements.#Task: Your will be provided with current and history questions asked by a user,along with their mapped schema elements(metric, dimension and value),please try understanding the semantics and rewrite a question.#Rules: 1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges.2.ONLY respond with the rewritten question.#Current Question: {{current_question}}#Current Mapped Schema: {{current_schema}}#History Question: {{history_question}}#History Mapped Schema: {{history_schema}}#History SQL: {{history_sql}}#Rewritten Question: ";
    private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = "#Role: You are a data business partner who closely interacts with business people.\n#Task: Your will be provided with user input, system output and some examples, please respond shortly to teach user how to ask the right question, using `Examples` as references.#Rules: ALWAYS use the same language as the `Input`.\n#Input: {{user_question}}\n#Output: {{system_message}}\n#Examples: {{examples}}\n#Response: ";

    @Override
    public void parse(ParseContext parseContext, ParseResp parseResp) {
        if (!parseContext.enableNL2SQL() || this.checkSkip(parseResp)) {
            return;
        }
        ChatContextService chatContextService = (ChatContextService)ContextUtils.getBean(ChatContextService.class);
        ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
        ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel((ChatModelConfig)parseContext.getAgent().getModelConfig());
        this.processMultiTurn(chatLanguageModel, parseContext);
        QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
        this.addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
        ChatLayerService chatLayerService = (ChatLayerService)ContextUtils.getBean(ChatLayerService.class);
        ParseResp text2SqlParseResp = chatLayerService.performParsing(queryNLReq);
        if (ParseResp.ParseState.COMPLETED.equals((Object)text2SqlParseResp.getState())) {
            parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
        } else {
            parseResp.setErrorMsg(this.rewriteErrorMessage(chatLanguageModel, parseContext.getQueryText(), text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(), parseContext.getAgent().getExamples()));
        }
        parseResp.setState(text2SqlParseResp.getState());
        parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
        this.formatParseResult(parseResp);
    }

    private boolean checkSkip(ParseResp parseResp) {
        List selectedParses = parseResp.getSelectedParses();
        for (SemanticParseInfo semanticParseInfo : selectedParses) {
            if (!(semanticParseInfo.getScore() >= (double)parseResp.getQueryText().length())) continue;
            return true;
        }
        return false;
    }

    private void formatParseResult(ParseResp parseResp) {
        List selectedParses = parseResp.getSelectedParses();
        for (SemanticParseInfo parseInfo : selectedParses) {
            this.formatParseInfo(parseInfo);
        }
    }

    private void formatParseInfo(SemanticParseInfo parseInfo) {
        if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
            this.formatNL2SQLParseInfo(parseInfo);
        }
    }

    private void formatNL2SQLParseInfo(SemanticParseInfo parseInfo) {
        StringBuilder textBuilder = new StringBuilder();
        textBuilder.append("**\u6570\u636e\u96c6:** ").append(parseInfo.getDataSet().getName()).append(" ");
        Optional metric = parseInfo.getMetrics().stream().findFirst();
        metric.ifPresent(schemaElement -> textBuilder.append("**\u6307\u6807:** ").append(schemaElement.getName()).append(" "));
        List dimensionNames = parseInfo.getDimensions().stream().map(SchemaElement::getName).filter(Objects::nonNull).collect(Collectors.toList());
        if (!CollectionUtils.isEmpty(dimensionNames)) {
            textBuilder.append("**\u7ef4\u5ea6:** ").append(String.join((CharSequence)",", dimensionNames));
        }
        textBuilder.append("\n\n**\u7b5b\u9009\u6761\u4ef6:** \n");
        if (parseInfo.getDateInfo() != null) {
            textBuilder.append("**\u6570\u636e\u65f6\u95f4:** ").append(parseInfo.getDateInfo().getStartDate()).append("~").append(parseInfo.getDateInfo().getEndDate()).append(" ");
        }
        if (!CollectionUtils.isEmpty((Collection)parseInfo.getDimensionFilters()) || CollectionUtils.isEmpty((Collection)parseInfo.getMetricFilters())) {
            Set queryFilters = parseInfo.getDimensionFilters();
            queryFilters.addAll(parseInfo.getMetricFilters());
            for (QueryFilter queryFilter : queryFilters) {
                textBuilder.append("**").append(queryFilter.getName()).append("**").append(" ").append(queryFilter.getOperator().getValue()).append(" ").append(queryFilter.getValue()).append(" ");
            }
        }
        parseInfo.setTextInfo(textBuilder.toString());
    }

    private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
        ParserConfig parserConfig = (ParserConfig)((Object)ContextUtils.getBean(ParserConfig.class));
        MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
        Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(ParserConfig.PARSER_MULTI_TURN_ENABLE));
        Boolean multiTurnConfig = agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig.booleanValue();
        if (!Boolean.TRUE.equals(multiTurnConfig)) {
            return;
        }
        ChatLayerService chatLayerService = (ChatLayerService)ContextUtils.getBean(ChatLayerService.class);
        QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
        MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
        List<ParseResp> historyParseResults = this.getHistoryParseResult(parseContext.getChatId(), 1);
        if (historyParseResults.size() == 0) {
            return;
        }
        ParseResp lastParseResult = historyParseResults.get(0);
        Long dataId = ((SemanticParseInfo)lastParseResult.getSelectedParses().get(0)).getDataSetId();
        String curtMapStr = this.generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
        String histMapStr = this.generateSchemaPrompt(((SemanticParseInfo)lastParseResult.getSelectedParses().get(0)).getElementMatches());
        String histSQL = ((SemanticParseInfo)lastParseResult.getSelectedParses().get(0)).getSqlInfo().getCorrectedS2SQL();
        HashMap<String, String> variables = new HashMap<String, String>();
        variables.put("current_question", currentMapResult.getQueryText());
        variables.put("current_schema", curtMapStr);
        variables.put("history_question", lastParseResult.getQueryText());
        variables.put("history_schema", histMapStr);
        variables.put("history_sql", histSQL);
        Prompt prompt = PromptTemplate.from((String)REWRITE_USER_QUESTION_INSTRUCTION).apply(variables);
        keyPipelineLog.info("NL2SQLParser reqPrompt:{}", (Object)prompt.text());
        Response response = chatLanguageModel.generate(new ChatMessage[]{prompt.toUserMessage()});
        String rewrittenQuery = ((AiMessage)response.content()).text();
        keyPipelineLog.info("NL2SQLParser modelResp:{}", (Object)rewrittenQuery);
        parseContext.setQueryText(rewrittenQuery);
        log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", new Object[]{lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery});
    }

    private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion, String errMsg, List<SqlExemplar> similarExemplars, List<String> agentExamples) {
        HashMap<String, CharSequence> variables = new HashMap<String, CharSequence>();
        variables.put("user_question", userQuestion);
        variables.put("system_message", errMsg);
        StringBuilder exampleStr = new StringBuilder();
        if (similarExemplars.size() > 0) {
            similarExemplars.stream().forEach(e -> exampleStr.append(String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
        } else {
            agentExamples.stream().forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
        }
        variables.put("examples", exampleStr);
        Prompt prompt = PromptTemplate.from((String)REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
        keyPipelineLog.info("NL2SQLParser reqPrompt:{}", (Object)prompt.text());
        Response response = chatLanguageModel.generate(new ChatMessage[]{prompt.toUserMessage()});
        String result = ((AiMessage)response.content()).text();
        keyPipelineLog.info("NL2SQLParser modelResp:{}", (Object)result);
        return ((AiMessage)response.content()).text();
    }

    private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
        ArrayList<String> metrics = new ArrayList<String>();
        ArrayList<String> dimensions = new ArrayList<String>();
        ArrayList<String> values = new ArrayList<String>();
        for (SchemaElementMatch match : elementMatches) {
            if (match.getElement().getType().equals((Object)SchemaElementType.METRIC)) {
                metrics.add(match.getWord());
                continue;
            }
            if (match.getElement().getType().equals((Object)SchemaElementType.DIMENSION)) {
                dimensions.add(match.getWord());
                continue;
            }
            if (!match.getElement().getType().equals((Object)SchemaElementType.VALUE)) continue;
            values.add(match.getWord());
        }
        StringBuilder prompt = new StringBuilder();
        prompt.append(String.format("'metrics:':[%s]", String.join((CharSequence)",", metrics)));
        prompt.append(",");
        prompt.append(String.format("'dimensions:':[%s]", String.join((CharSequence)",", dimensions)));
        prompt.append(",");
        prompt.append(String.format("'values:':[%s]", String.join((CharSequence)",", values)));
        return prompt.toString();
    }

    private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
        ChatQueryRepository chatQueryRepository = (ChatQueryRepository)ContextUtils.getBean(ChatQueryRepository.class);
        List contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId).stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
        List<ParseResp> contextualList = contextualParseInfoList.subList(0, Math.min(multiNum, contextualParseInfoList.size()));
        Collections.reverse(contextualList);
        return contextualList;
    }

    private void addDynamicExemplars(Integer agentId, QueryNLReq queryNLReq) {
        ExemplarServiceImpl exemplarManager = (ExemplarServiceImpl)ContextUtils.getBean(ExemplarServiceImpl.class);
        EmbeddingConfig embeddingConfig = (EmbeddingConfig)ContextUtils.getBean(EmbeddingConfig.class);
        String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
        List exemplars = exemplarManager.recallExemplars(memoryCollectionName, queryNLReq.getQueryText(), 5);
        queryNLReq.getDynamicExemplars().addAll(exemplars);
    }
}

