package com.tencent.supersonic.headless.server.processor;

import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/tencent/supersonic/headless/server/processor/SqlInfoProcessor.class */
/**
 * Result processor that attaches physical SQL to every candidate parse of a query.
 *
 * <p>For each candidate {@link SemanticQuery} held by the {@link QueryContext}, the parse info is
 * rebuilt into a semantic query request and passed to {@link SemanticLayerService} {@code explain}
 * with {@link QueryMethod#SQL}. A non-blank SQL result is stored in the parse's {@link SqlInfo}
 * together with the source id returned by the explain call. The total time spent is recorded as
 * the SQL phase of the parse time cost.
 */
public class SqlInfoProcessor implements ResultProcessor {

    private static final Logger log = LoggerFactory.getLogger(SqlInfoProcessor.class);

    /** Named logger used to trace S2SQL -> final SQL for LLM-generated queries. */
    private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

    /**
     * Fills in the query SQL for all candidate parses and records the elapsed time.
     *
     * @param parseResp response whose parse time cost receives the SQL-phase duration
     * @param queryContext context providing the candidate queries and the requesting user
     * @param chatContext unused by this processor
     */
    @Override
    public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
        long start = System.currentTimeMillis();
        // Typed (not raw) so that SemanticQuery::getParseInfo resolves and no unchecked cast is needed.
        List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
        if (CollectionUtils.isEmpty(candidateQueries)) {
            // Nothing to explain; sqlTime is intentionally left unset in this case.
            return;
        }
        List<SemanticParseInfo> parseInfos = candidateQueries.stream()
                .map(SemanticQuery::getParseInfo)
                .collect(Collectors.toList());
        addSqlInfo(queryContext, parseInfos);
        parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
    }

    /**
     * Adds SQL info to each parse independently; a failure on one parse is logged and does not
     * prevent the remaining parses from being processed.
     *
     * @param queryContext context providing the requesting user
     * @param parseInfos parse infos to enrich; null or empty is a no-op
     */
    private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> parseInfos) {
        if (CollectionUtils.isEmpty(parseInfos)) {
            return;
        }
        for (SemanticParseInfo parseInfo : parseInfos) {
            try {
                addSqlInfo(queryContext, parseInfo);
            } catch (Exception e) {
                // Best effort: one bad candidate must not break SQL generation for the others.
                log.warn("get sql info failed:{}", parseInfo, e);
            }
        }
    }

    /**
     * Explains a single parse into SQL and stores the result in its {@link SqlInfo}.
     *
     * <p>Returns without changes when no {@link SemanticQuery} can be created for the parse's
     * query mode, or when the explained SQL is blank.
     *
     * @param queryContext context providing the requesting user
     * @param semanticParseInfo parse to enrich; its {@code SqlInfo} is mutated in place
     * @throws Exception propagated from request building or the explain call
     */
    private void addSqlInfo(QueryContext queryContext, SemanticParseInfo semanticParseInfo)
            throws Exception {
        SemanticQuery semanticQuery = QueryManager.createQuery(semanticParseInfo.getQueryMode());
        if (Objects.isNull(semanticQuery)) {
            return;
        }
        semanticQuery.setParseInfo(semanticParseInfo);

        // Looked up per parse (inside the caller's try/catch) so a missing bean is logged, not thrown.
        SemanticLayerService semanticLayerService = ContextUtils.getBean(SemanticLayerService.class);
        ExplainResp explainResp = semanticLayerService.explain(
                ExplainSqlReq.builder()
                        .queryReq(semanticQuery.buildSemanticQueryReq())
                        .queryTypeEnum(QueryMethod.SQL)
                        .build(),
                queryContext.getUser());

        String querySql = explainResp.getSql();
        if (StringUtils.isBlank(querySql)) {
            return;
        }

        // NOTE(review): assumes getSqlInfo() is non-null; a null here surfaces as an NPE that the
        // caller logs as a warning — confirm SemanticParseInfo initializes it.
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        if (semanticQuery instanceof LLMSqlQuery) {
            keyPipelineLog.info(
                    "SqlInfoProcessor results:\nParsed S2SQL:{}\nCorrected S2SQL:{}\nFinal SQL:{}",
                    sqlInfo.getS2SQL(), sqlInfo.getCorrectS2SQL(), querySql);
        }
        sqlInfo.setQuerySQL(querySql);
        sqlInfo.setSourceId(explainResp.getSourceId());
    }
}
