/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.BaseSemanticCorrector;
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/**
 * Semantic corrector that rewrites the corrected S2SQL held by a
 * {@link SemanticParseInfo} in several passes: aggregate-function replacement,
 * alias replacement, field-name / field-value replacement driven by the linking
 * values found in the parse context, and finally field-name replacement driven
 * by the data set's field-name map.
 *
 * <p>Every pass reads {@code SqlInfo#getCorrectedS2SQL()}, hands it to a helper
 * and stores the helper's result back with {@code SqlInfo#setCorrectedS2SQL}.
 */
public class SchemaCorrector extends BaseSemanticCorrector {
    private static final Logger log = LoggerFactory.getLogger(SchemaCorrector.class);

    /** Key of the parse-context entry in the parse info's property map. */
    private static final String CONTEXT_PROPERTY = "CONTEXT";

    /**
     * Runs all correction passes in a fixed order. Each pass consumes the SQL
     * produced by the previous one, so the order is significant.
     *
     * @param chatQueryContext query context, used to resolve the field-name map
     * @param semanticParseInfo parse info whose corrected S2SQL is rewritten in place
     */
    @Override
    public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
        this.correctAggFunction(semanticParseInfo);
        this.replaceAlias(semanticParseInfo);
        // The linking values come from the property map, which none of the passes
        // modify, so resolve (and JSON-convert) them once for both passes below.
        List<LLMReq.ElementValue> linkingValues = this.getLinkingValues(semanticParseInfo);
        this.updateFieldNameByLinkingValue(semanticParseInfo, linkingValues);
        this.updateFieldValueByLinkingValue(semanticParseInfo, linkingValues);
        this.correctFieldName(chatQueryContext, semanticParseInfo);
    }

    /**
     * Replaces function names in the corrected S2SQL using the mapping supplied
     * by {@link AggregateEnum#getAggregateEnum()}.
     */
    private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlReplaceHelper.replaceFunction(
                sqlInfo.getCorrectedS2SQL(), AggregateEnum.getAggregateEnum());
        sqlInfo.setCorrectedS2SQL(sql);
    }

    /** Passes the corrected S2SQL through {@link SqlReplaceHelper#replaceAlias}. */
    private void replaceAlias(SemanticParseInfo semanticParseInfo) {
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectedS2SQL());
        sqlInfo.setCorrectedS2SQL(sql);
    }

    /**
     * Replaces field names in the corrected S2SQL using the field-name map that
     * the base class resolves for the parse info's data set.
     */
    private void correctFieldName(ChatQueryContext chatQueryContext,
            SemanticParseInfo semanticParseInfo) {
        Map<String, String> fieldNameMap =
                this.getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
        sqlInfo.setCorrectedS2SQL(sql);
    }

    /**
     * Groups the linking values by field value (value -> set of field names) and
     * hands that map to {@link SqlReplaceHelper#replaceFieldNameByValue}.
     *
     * @param linking linking values of the parse; may be empty, must not be null
     */
    private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo,
            List<LLMReq.ElementValue> linking) {
        if (CollectionUtils.isEmpty(linking)) {
            return;
        }
        // groupingBy rejects null keys, so incomplete elements are skipped up front.
        Map<String, Set<String>> fieldValueToFieldNames = linking.stream()
                .filter(SchemaCorrector::isComplete)
                .collect(Collectors.groupingBy(LLMReq.ElementValue::getFieldValue,
                        Collectors.mapping(LLMReq.ElementValue::getFieldName,
                                Collectors.toSet())));
        if (fieldValueToFieldNames.isEmpty()) {
            return;
        }
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlReplaceHelper.replaceFieldNameByValue(
                sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
        sqlInfo.setCorrectedS2SQL(sql);
    }

    /**
     * Extracts the linking values from the {@code "CONTEXT"} property of the parse info.
     *
     * <p>The stored context object is serialized to JSON and read back as a
     * {@link ParseResult} (presumably because the property is held as a generic
     * object rather than a typed instance - confirm against the producer).
     *
     * @return the linking values, or an empty list when the context, the parse
     *         result, its LLM request or the linking values themselves are absent;
     *         never {@code null}
     */
    private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
        Map<String, Object> properties = semanticParseInfo.getProperties();
        Object context = properties == null ? null : properties.get(CONTEXT_PROPERTY);
        if (Objects.isNull(context)) {
            return new ArrayList<>();
        }
        ParseResult parseResult =
                JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
        if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
            return new ArrayList<>();
        }
        List<LLMReq.ElementValue> linkingValues = parseResult.getLinkingValues();
        return Objects.isNull(linkingValues) ? new ArrayList<>() : linkingValues;
    }

    /**
     * Groups the linking values by field name (name -> {value -> value}) and hands
     * that map to {@link SqlReplaceHelper#replaceValue} with its boolean flag set
     * to {@code false}.
     *
     * @param linking linking values of the parse; may be empty, must not be null
     */
    private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo,
            List<LLMReq.ElementValue> linking) {
        if (CollectionUtils.isEmpty(linking)) {
            return;
        }
        // Both groupingBy (null key) and toMap (null value) throw on nulls, so
        // incomplete elements are skipped up front. Duplicate values keep the last one.
        Map<String, Map<String, String>> fieldNameToValueMap = linking.stream()
                .filter(SchemaCorrector::isComplete)
                .collect(Collectors.groupingBy(LLMReq.ElementValue::getFieldName,
                        Collectors.mapping(LLMReq.ElementValue::getFieldValue,
                                Collectors.toMap(value -> value, value -> value,
                                        (existing, replacement) -> replacement))));
        if (fieldNameToValueMap.isEmpty()) {
            return;
        }
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlReplaceHelper.replaceValue(
                sqlInfo.getCorrectedS2SQL(), fieldNameToValueMap, false);
        sqlInfo.setCorrectedS2SQL(sql);
    }

    /**
     * Removes WHERE conditions whose field is a dimension of the data set but does
     * not appear among the linking values.
     *
     * <p>A condition is removed only if it has no function, is not on a time
     * dimension, uses the equals operator, targets a known dimension, does not
     * compare against a date string, and its field name is not a linking field.
     * This method is not part of {@link #doCorrect}; callers invoke it explicitly.
     *
     * @param chatQueryContext query context supplying the semantic schema
     * @param semanticParseInfo parse info whose corrected S2SQL is rewritten in place
     */
    public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
            SemanticParseInfo semanticParseInfo) {
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String correctS2SQL = sqlInfo.getCorrectedS2SQL();
        List<FieldExpression> whereExpressions = SqlSelectHelper.getWhereExpressions(correctS2SQL);
        if (CollectionUtils.isEmpty(whereExpressions)) {
            return;
        }

        SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
        Set<String> dimensions =
                this.getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);
        Set<String> linkingFieldNames = this.getLinkingValues(semanticParseInfo).stream()
                .filter(Objects::nonNull)
                .map(LLMReq.ElementValue::getFieldName)
                .collect(Collectors.toSet());

        Set<String> removeFieldNames = whereExpressions.stream()
                .filter(expression -> isRemovableFilter(expression, dimensions, linkingFieldNames))
                .map(FieldExpression::getFieldName)
                .collect(Collectors.toSet());

        String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
        sqlInfo.setCorrectedS2SQL(sql);
    }

    /**
     * Tells whether a WHERE expression qualifies for removal; see
     * {@link #removeFilterIfNotInLinkingValue} for the individual criteria.
     */
    private static boolean isRemovableFilter(FieldExpression expression,
            Set<String> dimensions, Set<String> linkingFieldNames) {
        String fieldName = expression.getFieldName();
        Object fieldValue = expression.getFieldValue();
        // A missing value cannot be a date string; guard instead of failing on toString().
        boolean comparesToDate =
                fieldValue != null && DateUtils.isAnyDateString(fieldValue.toString());
        return StringUtils.isBlank(expression.getFunction())
                && !TimeDimensionEnum.containsTimeDimension(fieldName)
                && FilterOperatorEnum.EQUALS.getValue().equals(expression.getOperator())
                && dimensions.contains(fieldName)
                && !comparesToDate
                && !linkingFieldNames.contains(fieldName);
    }

    /** An element is usable for grouping only when it carries both a name and a value. */
    private static boolean isComplete(LLMReq.ElementValue element) {
        return element != null
                && element.getFieldName() != null
                && element.getFieldValue() != null;
    }
}

