package com.tencent.supersonic.headless.chat.parser.llm;

import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

/**
 * Builds the textual pieces of the LLM prompt used by the parser.
 *
 * <p>It produces three things: the few-shot exemplar sets, the question augmented with
 * linking/date/term hints, and a compact schema description of the target data set.
 */
@Component
public class PromptHelper {
    private static final Logger log = LoggerFactory.getLogger(PromptHelper.class);

    @Autowired
    private ParserConfig parserConfig;

    @Autowired
    private ExemplarService exemplarService;

    /**
     * Builds one shuffled few-shot exemplar list per self-consistency round.
     *
     * <p>The candidate pool starts with the request's dynamic exemplars. If that holds fewer than
     * the configured recall number, the remainder is recalled from {@link ExemplarService} using
     * the request's append query text. Each round shuffles an independent copy of the pool and
     * keeps at most the configured few-shot number of entries. If the pool is smaller than that
     * number, the whole pool is used instead of failing.
     *
     * @param lLMReq the parse request supplying dynamic exemplars and the query text
     * @return one exemplar list per self-consistency round; each inner list is an independent,
     *     mutable copy (not a view of a shared list)
     */
    public List<List<SqlExemplar>> getFewShotExemplars(LLMReq lLMReq) {
        int recallNumber =
                Integer.valueOf(this.parserConfig.getParameterValue(ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER));
        int fewShotNumber =
                Integer.valueOf(this.parserConfig.getParameterValue(ParserConfig.PARSER_FEW_SHOT_NUMBER));
        int selfConsistencyNumber =
                Integer.valueOf(this.parserConfig.getParameterValue(ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER));

        List<SqlExemplar> pool = new ArrayList<>();
        if (!CollectionUtils.isEmpty(lLMReq.getDynamicExemplars())) {
            pool.addAll(lLMReq.getDynamicExemplars());
        }
        // Top up the pool from the exemplar store only when dynamic exemplars fall short.
        int missing = recallNumber - pool.size();
        if (missing > 0) {
            pool.addAll(this.exemplarService.recallExemplars(lLMReq.getAppendQueryText(), missing));
        }

        // Clamp so subList never exceeds the pool (previously threw IndexOutOfBoundsException).
        int perRound = Math.max(0, Math.min(fewShotNumber, pool.size()));
        List<List<SqlExemplar>> rounds = new ArrayList<>(Math.max(0, selfConsistencyNumber));
        for (int round = 0; round < selfConsistencyNumber; round++) {
            List<SqlExemplar> shuffled = new ArrayList<>(pool);
            Collections.shuffle(shuffled);
            // Copy the slice so callers do not receive a view backed by a throwaway list.
            rounds.add(new ArrayList<>(shuffled.subList(0, perRound)));
        }
        return rounds;
    }

    /**
     * Appends supplementary hints to the user's question.
     *
     * <p>The hints are, in order: the linked field values, the current date, the business-term
     * description from {@link #buildTermStr(LLMReq)}, and the prior extensions.
     *
     * @param lLMReq the parse request supplying query text, linking, date, terms and prior exts
     * @return the question followed by a parenthesised block of supplementary information
     */
    public String buildAugmentedQuestion(LLMReq lLMReq) {
        List<LLMReq.ElementValue> linking = lLMReq.getLinking();
        String currentDate = lLMReq.getCurrentDate();
        String priorExts = lLMReq.getPriorExts();

        List<String> linkingHints = new ArrayList<>();
        // Tolerate a missing linking list; previously a null value caused an NPE.
        if (!CollectionUtils.isEmpty(linking)) {
            for (LLMReq.ElementValue elementValue : linking) {
                linkingHints.add("‘" + elementValue.getFieldValue() + "‘是一个‘" + elementValue.getFieldName() + "‘");
            }
        }
        return String.format("%s (补充信息:%s;%s;%s;%s)", lLMReq.getAppendQueryText(),
                String.join("，", linkingHints), "当前的日期是" + currentDate, buildTermStr(lLMReq), priorExts);
    }

    /**
     * Describes the request's data set as a single line.
     *
     * <p>Metrics carry an optional {@code COMMENT} and an optional upper-cased {@code AGGREGATE}.
     * Dimensions carry an optional {@code COMMENT}.
     *
     * @param lLMReq the parse request whose schema is described
     * @return a string of the form {@code Table: t, Metrics: [m1,m2], Dimensions: [d1,d2]}
     */
    public String buildSchemaStr(LLMReq lLMReq) {
        String dataSetName = lLMReq.getSchema().getDataSetName();
        List<String> metricColumns = new ArrayList<>();
        List<String> dimensionColumns = new ArrayList<>();
        lLMReq.getSchema().getMetrics().forEach(metric -> metricColumns.add(
                describeColumn(metric.getName(), metric.getDescription(), metric.getDefaultAgg())));
        lLMReq.getSchema().getDimensions().forEach(dimension -> dimensionColumns.add(
                describeColumn(dimension.getName(), dimension.getDescription(), null)));
        // Join instead of appending "," per element so the lists carry no trailing comma.
        return String.format("Table: %s, Metrics: [%s], Dimensions: [%s]", dataSetName,
                String.join(",", metricColumns), String.join(",", dimensionColumns));
    }

    /** Renders one schema column as {@code name[ COMMENT 'desc'][ AGGREGATE 'AGG']}. */
    private static String describeColumn(String name, String description, String defaultAgg) {
        StringBuilder column = new StringBuilder().append(name);
        if (StringUtils.isNotEmpty(description)) {
            column.append(" COMMENT '").append(description).append("'");
        }
        if (StringUtils.isNotEmpty(defaultAgg)) {
            // Locale.ROOT keeps e.g. "min" -> "MIN" regardless of the JVM default locale.
            column.append(" AGGREGATE '").append(defaultAgg.toUpperCase(Locale.ROOT)).append("'");
        }
        return column.toString();
    }

    /**
     * Renders the schema's business terms as a numbered, "；"-separated list.
     *
     * <p>Each term shows its name, its description when non-blank, and its aliases when present.
     *
     * @param lLMReq the parse request whose schema terms are rendered
     * @return the term description, or an empty string when there are no terms
     */
    private String buildTermStr(LLMReq lLMReq) {
        List<LLMReq.Term> terms = lLMReq.getSchema().getTerms();
        if (CollectionUtils.isEmpty(terms)) {
            return "";
        }
        List<String> entries = new ArrayList<>(terms.size());
        for (int i = 0; i < terms.size(); i++) {
            LLMReq.Term term = terms.get(i);
            String description = term.getDescription();
            List<String> alias = term.getAlias();
            String descriptionPart = StringUtils.isBlank(description)
                    ? "" : String.format("，它通常是指<%s>", description);
            String aliasPart = CollectionUtils.isEmpty(alias)
                    ? "" : String.format("，类似的表达还有%s", alias);
            entries.add(String.format("%d.<%s>是业务术语%s%s", i + 1, term.getName(), descriptionPart, aliasPart));
        }
        return "相关业务术语：" + String.join("；", entries);
    }
}
