package com.tencent.supersonic.chat.server.parser;

import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.ModelProvider;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/tencent/supersonic/chat/server/parser/NL2SQLParser.class */
public class NL2SQLParser implements ChatQueryParser {
    private static final Logger log = LoggerFactory.getLogger(NL2SQLParser.class);
    private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
    private static final String REWRITE_USER_QUESTION_INSTRUCTION = "#Role: You are a data product manager experienced in data requirements.#Task: Your will be provided with current and history questions asked by a user,along with their mapped schema elements(metric, dimension and value),please try understanding the semantics and rewrite a question.#Rules: 1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges.2.ONLY respond with the rewritten question.#Current Question: {{current_question}}#Current Mapped Schema: {{current_schema}}#History Question: {{history_question}}#History Mapped Schema: {{history_schema}}#History SQL: {{history_sql}}#Rewritten Question: ";
    private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = "#Role: You are a data business partner who closely interacts with business people.\n#Task: Your will be provided with user input, system output and some examples, please respond shortly to teach user how to ask the right question, using `Examples` as references.#Rules: ALWAYS use the same language as the `Input`.\n#Input: {{user_question}}\n#Output: {{system_message}}\n#Examples: {{examples}}\n#Response: ";

    @Override // com.tencent.supersonic.chat.server.parser.ChatQueryParser
    public void parse(ParseContext parseContext, ParseResp parseResp) {
        if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
            return;
        }
        ChatContext orCreateContext = ((ChatContextService) ContextUtils.getBean(ChatContextService.class)).getOrCreateContext(parseContext.getChatId());
        ChatLanguageModel chatModel = ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
        processMultiTurn(chatModel, parseContext);
        QueryNLReq buildText2SqlQueryReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, orCreateContext);
        addDynamicExemplars(parseContext.getAgent().getId(), buildText2SqlQueryReq);
        ParseResp performParsing = ((ChatLayerService) ContextUtils.getBean(ChatLayerService.class)).performParsing(buildText2SqlQueryReq);
        if (ParseResp.ParseState.COMPLETED.equals(performParsing.getState())) {
            parseResp.getSelectedParses().addAll(performParsing.getSelectedParses());
        } else {
            parseResp.setErrorMsg(rewriteErrorMessage(chatModel, parseContext.getQueryText(), performParsing.getErrorMsg(), buildText2SqlQueryReq.getDynamicExemplars(), parseContext.getAgent().getExamples()));
        }
        parseResp.setState(performParsing.getState());
        parseResp.getParseTimeCost().setSqlTime(performParsing.getParseTimeCost().getSqlTime());
        formatParseResult(parseResp);
    }

    private boolean checkSkip(ParseResp parseResp) {
        Iterator it = parseResp.getSelectedParses().iterator();
        while (it.hasNext()) {
            if (((SemanticParseInfo) it.next()).getScore() >= parseResp.getQueryText().length()) {
                return true;
            }
        }
        return false;
    }

    private void formatParseResult(ParseResp parseResp) {
        Iterator it = parseResp.getSelectedParses().iterator();
        while (it.hasNext()) {
            formatParseInfo((SemanticParseInfo) it.next());
        }
    }

    private void formatParseInfo(SemanticParseInfo semanticParseInfo) {
        if (PluginQueryManager.isPluginQuery(semanticParseInfo.getQueryMode())) {
            return;
        }
        formatNL2SQLParseInfo(semanticParseInfo);
    }

    private void formatNL2SQLParseInfo(SemanticParseInfo semanticParseInfo) {
        StringBuilder sb = new StringBuilder();
        sb.append("**数据集:** ").append(semanticParseInfo.getDataSet().getName()).append(" ");
        semanticParseInfo.getMetrics().stream().findFirst().ifPresent(schemaElement -> {
            sb.append("**指标:** ").append(schemaElement.getName()).append(" ");
        });
        List list = (List) semanticParseInfo.getDimensions().stream().map((v0) -> {
            return v0.getName();
        }).filter((v0) -> {
            return Objects.nonNull(v0);
        }).collect(Collectors.toList());
        if (!CollectionUtils.isEmpty(list)) {
            sb.append("**维度:** ").append(String.join(",", list));
        }
        sb.append("\n\n**筛选条件:** \n");
        if (semanticParseInfo.getDateInfo() != null) {
            sb.append("**数据时间:** ").append(semanticParseInfo.getDateInfo().getStartDate()).append("~").append(semanticParseInfo.getDateInfo().getEndDate()).append(" ");
        }
        if (!CollectionUtils.isEmpty(semanticParseInfo.getDimensionFilters()) || CollectionUtils.isEmpty(semanticParseInfo.getMetricFilters())) {
            Set<QueryFilter> dimensionFilters = semanticParseInfo.getDimensionFilters();
            dimensionFilters.addAll(semanticParseInfo.getMetricFilters());
            for (QueryFilter queryFilter : dimensionFilters) {
                sb.append("**").append(queryFilter.getName()).append("**").append(" ").append(queryFilter.getOperator().getValue()).append(" ").append(queryFilter.getValue()).append(" ");
            }
        }
        semanticParseInfo.setTextInfo(sb.toString());
    }

    private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
        ParserConfig parserConfig = (ParserConfig) ContextUtils.getBean(ParserConfig.class);
        MultiTurnConfig multiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
        if (Boolean.TRUE.equals(Boolean.valueOf(multiTurnConfig != null ? multiTurnConfig.isEnableMultiTurn() : Boolean.valueOf(parserConfig.getParameterValue(ParserConfig.PARSER_MULTI_TURN_ENABLE)).booleanValue()))) {
            MapResp performMapping = ((ChatLayerService) ContextUtils.getBean(ChatLayerService.class)).performMapping(QueryReqConverter.buildText2SqlQueryReq(parseContext));
            List<ParseResp> historyParseResult = getHistoryParseResult(parseContext.getChatId().intValue(), 1);
            if (historyParseResult.size() == 0) {
                return;
            }
            ParseResp parseResp = historyParseResult.get(0);
            String generateSchemaPrompt = generateSchemaPrompt(performMapping.getMapInfo().getMatchedElements(((SemanticParseInfo) parseResp.getSelectedParses().get(0)).getDataSetId()));
            String generateSchemaPrompt2 = generateSchemaPrompt(((SemanticParseInfo) parseResp.getSelectedParses().get(0)).getElementMatches());
            String correctedS2SQL = ((SemanticParseInfo) parseResp.getSelectedParses().get(0)).getSqlInfo().getCorrectedS2SQL();
            HashMap hashMap = new HashMap();
            hashMap.put("current_question", performMapping.getQueryText());
            hashMap.put("current_schema", generateSchemaPrompt);
            hashMap.put("history_question", parseResp.getQueryText());
            hashMap.put("history_schema", generateSchemaPrompt2);
            hashMap.put("history_sql", correctedS2SQL);
            Prompt apply = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(hashMap);
            keyPipelineLog.info("NL2SQLParser reqPrompt:{}", apply.text());
            String text = ((AiMessage) chatLanguageModel.generate(new ChatMessage[]{apply.toUserMessage()}).content()).text();
            keyPipelineLog.info("NL2SQLParser modelResp:{}", text);
            parseContext.setQueryText(text);
            log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", new Object[]{parseResp.getQueryText(), performMapping.getQueryText(), text});
        }
    }

    private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String str, String str2, List<SqlExemplar> list, List<String> list2) {
        HashMap hashMap = new HashMap();
        hashMap.put("user_question", str);
        hashMap.put("system_message", str2);
        StringBuilder sb = new StringBuilder();
        if (list.size() > 0) {
            list.stream().forEach(sqlExemplar -> {
                sb.append(String.format("<Question:{%s},Schema:{%s}> ", sqlExemplar.getQuestion(), sqlExemplar.getDbSchema()));
            });
        } else {
            list2.stream().forEach(str3 -> {
                sb.append(String.format("<Question:{%s}> ", str3));
            });
        }
        hashMap.put("examples", sb);
        Prompt apply = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(hashMap);
        keyPipelineLog.info("NL2SQLParser reqPrompt:{}", apply.text());
        Response generate = chatLanguageModel.generate(new ChatMessage[]{apply.toUserMessage()});
        keyPipelineLog.info("NL2SQLParser modelResp:{}", ((AiMessage) generate.content()).text());
        return ((AiMessage) generate.content()).text();
    }

    private String generateSchemaPrompt(List<SchemaElementMatch> list) {
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        for (SchemaElementMatch schemaElementMatch : list) {
            if (schemaElementMatch.getElement().getType().equals(SchemaElementType.METRIC)) {
                arrayList.add(schemaElementMatch.getWord());
            } else if (schemaElementMatch.getElement().getType().equals(SchemaElementType.DIMENSION)) {
                arrayList2.add(schemaElementMatch.getWord());
            } else if (schemaElementMatch.getElement().getType().equals(SchemaElementType.VALUE)) {
                arrayList3.add(schemaElementMatch.getWord());
            }
        }
        return String.format("'metrics:':[%s]", String.join(",", arrayList)) + "," + String.format("'dimensions:':[%s]", String.join(",", arrayList2)) + "," + String.format("'values:':[%s]", String.join(",", arrayList3));
    }

    private List<ParseResp> getHistoryParseResult(int i, int i2) {
        List list = (List) ((ChatQueryRepository) ContextUtils.getBean(ChatQueryRepository.class)).getContextualParseInfo(Integer.valueOf(i)).stream().filter(parseResp -> {
            return parseResp.getState() == ParseResp.ParseState.COMPLETED;
        }).collect(Collectors.toList());
        List<ParseResp> subList = list.subList(0, Math.min(i2, list.size()));
        Collections.reverse(subList);
        return subList;
    }

    private void addDynamicExemplars(Integer num, QueryNLReq queryNLReq) {
        queryNLReq.getDynamicExemplars().addAll(((ExemplarServiceImpl) ContextUtils.getBean(ExemplarServiceImpl.class)).recallExemplars(((EmbeddingConfig) ContextUtils.getBean(EmbeddingConfig.class)).getMemoryCollectionName(num), queryNLReq.getQueryText(), 5));
    }
}
