package com.tencent.supersonic.headless.chat.parser.llm;

import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

@Service
/* loaded from: input_file:com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.class */
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
    private static final Logger log = LoggerFactory.getLogger(OnePassSCSqlGenStrategy.class);
    private static final String INSTRUCTION = "#SQL Syntax: Using SQL syntax for MySQL database.\n#Role: You are a data analyst experienced in SQL languages.\n#Task: You will be provided a natural language question asked by users,please convert it to a SQL query so that relevant data could be returned by executing the SQL query against underlying database.\n#Rules:1.ALWAYS use `数据日期` as the date field.2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator.3.ALWAYS calculate the absolute date range by yourself.4.DO NOT include date filter in the where clause if not explicitly expressed in the question.5.ONLY respond with the converted SQL statement.\n#Exemplars:\n{{exemplar}}#Question:{{question}} #Schema:{{schema}} #SQL:";

    @Override // com.tencent.supersonic.headless.chat.parser.llm.SqlGenStrategy
    public LLMResp generate(LLMReq lLMReq) {
        keyPipelineLog.debug("OnePassSCSqlGenStrategy llmReq:\n{}", lLMReq);
        List<List<SqlExemplar>> fewShotExemplars = this.promptHelper.getFewShotExemplars(lLMReq);
        HashMap hashMap = new HashMap();
        for (List<SqlExemplar> list : fewShotExemplars) {
            hashMap.put(generatePrompt(lLMReq, list), list);
        }
        ConcurrentHashMap concurrentHashMap = new ConcurrentHashMap();
        hashMap.keySet().parallelStream().forEach(prompt -> {
            keyPipelineLog.debug("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
            String text = ((AiMessage) getChatLanguageModel(lLMReq.getModelConfig()).generate(new ChatMessage[]{prompt.toUserMessage()}).content()).text();
            concurrentHashMap.put(prompt, text);
            keyPipelineLog.debug("OnePassSCSqlGenStrategy modelResp:\n{}", text);
        });
        Pair<String, Map<String, Double>> selfConsistencyVote = ResponseHelper.selfConsistencyVote(Lists.newArrayList(concurrentHashMap.values()));
        LLMResp lLMResp = new LLMResp();
        lLMResp.setQuery(this.promptHelper.buildAugmentedQuestion(lLMReq));
        lLMResp.setDbSchema(this.promptHelper.buildSchemaStr(lLMReq));
        lLMResp.setSqlOutput((String) selfConsistencyVote.getLeft());
        lLMResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(fewShotExemplars.get(0), (Map) selfConsistencyVote.getRight()));
        return lLMResp;
    }

    private Prompt generatePrompt(LLMReq lLMReq, List<SqlExemplar> list) {
        StringBuilder sb = new StringBuilder();
        for (SqlExemplar sqlExemplar : list) {
            sb.append(String.format("#Question:%s #Schema:%s #SQL:%s\n", sqlExemplar.getQuestion(), sqlExemplar.getDbSchema(), sqlExemplar.getSql()));
        }
        String buildSchemaStr = this.promptHelper.buildSchemaStr(lLMReq);
        String buildAugmentedQuestion = this.promptHelper.buildAugmentedQuestion(lLMReq);
        HashMap hashMap = new HashMap();
        hashMap.put("exemplar", sb);
        hashMap.put("question", buildAugmentedQuestion);
        hashMap.put("schema", buildSchemaStr);
        PromptConfig promptConfig = lLMReq.getPromptConfig();
        String str = INSTRUCTION;
        if (promptConfig != null && StringUtils.isNotBlank(promptConfig.getPromptTemplate())) {
            str = promptConfig.getPromptTemplate();
        }
        return PromptTemplate.from(str).apply(hashMap);
    }

    public void afterPropertiesSet() {
        SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
    }
}
