/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.headless.chat.parser.llm;

import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

/**
 * Builds the pieces of the LLM text-to-SQL prompt: few-shot exemplar groups, the user question
 * augmented with linking/date/term hints, and a compact schema description.
 */
@Component
public class PromptHelper {
    private static final Logger log = LoggerFactory.getLogger(PromptHelper.class);
    @Autowired
    private ParserConfig parserConfig;
    @Autowired
    private ExemplarService exemplarService;

    /**
     * Builds one randomly shuffled few-shot exemplar group per self-consistency round.
     *
     * <p>The candidate pool starts with the request's dynamic exemplars. It is then topped up with
     * exemplars recalled from {@link ExemplarService}, up to the configured recall number. Each
     * returned group is an independent shuffle of that pool, truncated to at most the configured
     * few-shot number.
     *
     * @param llmReq the request; its dynamic exemplars and query text are read
     * @return a list with one exemplar group per self-consistency round; a group is shorter than
     *     the configured few-shot number when the pool has fewer exemplars than that
     */
    public List<List<SqlExemplar>> getFewShotExemplars(LLMReq llmReq) {
        int exemplarRecallNumber = Integer.parseInt(
                this.parserConfig.getParameterValue(ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER));
        int fewShotNumber = Integer.parseInt(
                this.parserConfig.getParameterValue(ParserConfig.PARSER_FEW_SHOT_NUMBER));
        int selfConsistencyNumber = Integer.parseInt(
                this.parserConfig.getParameterValue(ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER));

        List<SqlExemplar> exemplars = new ArrayList<>();
        List<SqlExemplar> dynamicExemplars = llmReq.getDynamicExemplars();
        if (!CollectionUtils.isEmpty(dynamicExemplars)) {
            exemplars.addAll(dynamicExemplars);
        }
        // Only recall from the exemplar store what the dynamic exemplars did not already cover.
        int recallSize = exemplarRecallNumber - exemplars.size();
        if (recallSize > 0) {
            exemplars.addAll(this.exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
        }

        // Clamp so subList never runs past the end of a small pool (or starts from a negative bound).
        int sampleSize = Math.min(Math.max(fewShotNumber, 0), exemplars.size());
        List<List<SqlExemplar>> results = new ArrayList<>();
        for (int i = 0; i < selfConsistencyNumber; ++i) {
            List<SqlExemplar> shuffledList = new ArrayList<>(exemplars);
            Collections.shuffle(shuffledList);
            results.add(shuffledList.subList(0, sampleSize));
        }
        return results;
    }

    /**
     * Appends supplementary hints to the user's question.
     *
     * <p>The result has the shape {@code question (supplementary-info:linking;date;terms;priorExts)}.
     * Each linked element value is rendered as a "'value' is a 'field'" phrase; these phrases are
     * joined with fullwidth commas. The hint text itself is emitted in Chinese.
     *
     * @param llmReq the request; its linking, current date, prior exts, terms and query text are read
     * @return the augmented question string
     */
    public String buildAugmentedQuestion(LLMReq llmReq) {
        List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
        String currentDate = llmReq.getCurrentDate();
        String priorExts = llmReq.getPriorExts();
        List<String> priorLinkingList = new ArrayList<>();
        if (!CollectionUtils.isEmpty(linkedValues)) {
            for (LLMReq.ElementValue value : linkedValues) {
                String fieldName = value.getFieldName();
                String fieldValue = value.getFieldValue();
                // Renders as: 'fieldValue' is a 'fieldName' (in Chinese).
                priorLinkingList.add("\u2018" + fieldValue + "\u2018\u662f\u4e00\u4e2a\u2018" + fieldName + "\u2018");
            }
        }
        // Renders as: "The current date is <currentDate>" (in Chinese).
        String currentDataStr = "\u5f53\u524d\u7684\u65e5\u671f\u662f" + currentDate;
        String linkingListStr = String.join("\uff0c", priorLinkingList);
        String termStr = this.buildTermStr(llmReq);
        return String.format("%s (\u8865\u5145\u4fe1\u606f:%s;%s;%s;%s)", llmReq.getQueryText(), linkingListStr, currentDataStr, termStr, priorExts);
    }

    /**
     * Renders the request schema as {@code Table: name, Metrics: [...], Dimensions: [...]}.
     *
     * <p>Each metric is written as its name, optionally followed by {@code COMMENT '<description>'}
     * and {@code AGGREGATE '<DEFAULT_AGG>'}. Each dimension is written as its name, optionally
     * followed by {@code COMMENT '<description>'}. Every item is followed by a comma, including the
     * last one; this trailing comma is kept so the prompt text stays unchanged.
     *
     * @param llmReq the request whose schema is described
     * @return the one-line schema description
     */
    public String buildSchemaStr(LLMReq llmReq) {
        String tableStr = llmReq.getSchema().getDataSetName();
        StringBuilder metricStr = new StringBuilder();
        StringBuilder dimensionStr = new StringBuilder();
        llmReq.getSchema().getMetrics().forEach(metric -> {
            metricStr.append(metric.getName());
            if (StringUtils.isNotEmpty(metric.getDescription())) {
                metricStr.append(" COMMENT '").append(metric.getDescription()).append("'");
            }
            if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
                // Locale.ROOT: a default-locale upper-casing (e.g. Turkish) would corrupt 'i'.
                metricStr.append(" AGGREGATE '")
                        .append(metric.getDefaultAgg().toUpperCase(Locale.ROOT))
                        .append("'");
            }
            metricStr.append(",");
        });
        llmReq.getSchema().getDimensions().forEach(dimension -> {
            dimensionStr.append(dimension.getName());
            if (StringUtils.isNotEmpty(dimension.getDescription())) {
                dimensionStr.append(" COMMENT '").append(dimension.getDescription()).append("'");
            }
            dimensionStr.append(",");
        });
        String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
        return String.format(template, tableStr, metricStr, dimensionStr);
    }

    /**
     * Describes the schema's business terms as a numbered list (in Chinese).
     *
     * <p>The output starts with a "Related business terms:" header. It then lists, for each term,
     * "N.<name> is a business term". An optional "it usually refers to <description>" clause is
     * added when the description is not blank, and an optional "similar expressions include
     * [aliases]" clause is added when aliases exist. Entries are separated by fullwidth semicolons.
     *
     * @param llmReq the request whose schema terms are described
     * @return the term description, or an empty string when the schema has no terms
     */
    private String buildTermStr(LLMReq llmReq) {
        List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
        if (CollectionUtils.isEmpty(terms)) {
            return "";
        }
        StringBuilder termsDesc = new StringBuilder();
        termsDesc.append("\u76f8\u5173\u4e1a\u52a1\u672f\u8bed\uff1a");
        for (int idx = 0; idx < terms.size(); ++idx) {
            LLMReq.Term term = terms.get(idx);
            String name = term.getName();
            String description = term.getDescription();
            List<String> alias = term.getAlias();
            String descPart = StringUtils.isBlank(description) ? "" : String.format("\uff0c\u5b83\u901a\u5e38\u662f\u6307<%s>", description);
            String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format("\uff0c\u7c7b\u4f3c\u7684\u8868\u8fbe\u8fd8\u6709%s", alias);
            termsDesc.append(String.format("%d.<%s>\u662f\u4e1a\u52a1\u672f\u8bed%s%s\uff1b", idx + 1, name, descPart, aliasPart));
        }
        // At least one entry was appended, so drop the final entry's trailing fullwidth semicolon.
        termsDesc.setLength(termsDesc.length() - 1);
        return termsDesc.toString();
    }
}

