package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.class */
/**
 * Applies schema-level corrections to the corrected S2SQL held in a {@link SemanticParseInfo}.
 *
 * <p>{@link #doCorrect} runs these steps in order:
 * <ol>
 *   <li>aggregate function names are rewritten using {@code AggregateEnum.getAggregateEnum()};</li>
 *   <li>aliases are replaced via {@code SqlReplaceHelper.replaceAlias};</li>
 *   <li>field names, then filter values, are rewritten using the linking values found in the
 *       parse context stored under the {@code "CONTEXT"} property;</li>
 *   <li>field names are replaced using the map returned by {@code getFieldNameMap}
 *       (defined in the base class).</li>
 * </ol>
 */
public class SchemaCorrector extends BaseSemanticCorrector {
    private static final Logger log = LoggerFactory.getLogger(SchemaCorrector.class);

    /** Property key under which the serialized {@link ParseResult} is stored. */
    private static final String CONTEXT_KEY = "CONTEXT";

    @Override
    public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
        correctAggFunction(semanticParseInfo);
        replaceAlias(semanticParseInfo);
        // Linking values depend only on the parse context, not on the SQL text, so the
        // JSON round-trip in getLinkingValues is done once and shared by both rewrites.
        List<LLMReq.ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
        updateFieldNameByLinkingValue(semanticParseInfo, linkingValues);
        updateFieldValueByLinkingValue(semanticParseInfo, linkingValues);
        correctFieldName(chatQueryContext, semanticParseInfo);
    }

    /** Rewrites aggregate function names in the corrected S2SQL using the AggregateEnum map. */
    private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
        Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        sqlInfo.setCorrectedS2SQL(
                SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectedS2SQL(), aggregateEnum));
    }

    /** Replaces aliases in the corrected S2SQL via {@code SqlReplaceHelper.replaceAlias}. */
    private void replaceAlias(SemanticParseInfo semanticParseInfo) {
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        sqlInfo.setCorrectedS2SQL(SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectedS2SQL()));
    }

    /** Replaces field names in the corrected S2SQL using the base class's field-name map. */
    private void correctFieldName(ChatQueryContext chatQueryContext,
            SemanticParseInfo semanticParseInfo) {
        Map<String, String> fieldNameMap =
                getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        sqlInfo.setCorrectedS2SQL(
                SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap));
    }

    /**
     * Rewrites field names in the corrected S2SQL based on linked values.
     *
     * <p>Builds a map from each linked field value to the set of field names it was linked to
     * and hands it to {@code SqlReplaceHelper.replaceFieldNameByValue}.
     *
     * @param semanticParseInfo parse info whose corrected S2SQL is rewritten in place
     * @param linkingValues linking values from the parse context; never null, may be empty
     */
    private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo,
            List<LLMReq.ElementValue> linkingValues) {
        // groupingBy rejects null keys, so incomplete entries are skipped instead of crashing.
        Map<String, Set<String>> fieldNamesByValue = linkingValues.stream()
                .filter(SchemaCorrector::hasFieldNameAndValue)
                .collect(Collectors.groupingBy(LLMReq.ElementValue::getFieldValue,
                        Collectors.mapping(LLMReq.ElementValue::getFieldName,
                                Collectors.toSet())));
        if (fieldNamesByValue.isEmpty()) {
            return;
        }
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        sqlInfo.setCorrectedS2SQL(SqlReplaceHelper.replaceFieldNameByValue(
                sqlInfo.getCorrectedS2SQL(), fieldNamesByValue));
    }

    /**
     * Deserializes the parse context stored under {@code "CONTEXT"} and returns its linking values.
     *
     * @return the linking values, or an empty list when the context, the parsed result, its LLM
     *     request, or the linking values themselves are absent; never null
     */
    private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
        Object context = semanticParseInfo.getProperties().get(CONTEXT_KEY);
        if (Objects.isNull(context)) {
            return Collections.emptyList();
        }
        // Round-trip through JSON to convert whatever object type is stored into a ParseResult.
        ParseResult parseResult =
                (ParseResult) JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
        if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
            return Collections.emptyList();
        }
        List<LLMReq.ElementValue> linkingValues = parseResult.getLinkingValues();
        return CollectionUtils.isEmpty(linkingValues) ? Collections.emptyList() : linkingValues;
    }

    /**
     * Rewrites filter values in the corrected S2SQL based on linked values.
     *
     * <p>Builds a map from field name to a value→value map of that field's linked values
     * (on duplicates the later entry wins) and passes it to
     * {@code SqlReplaceHelper.replaceValue} with {@code false} as its third argument.
     *
     * @param semanticParseInfo parse info whose corrected S2SQL is rewritten in place
     * @param linkingValues linking values from the parse context; never null, may be empty
     */
    private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo,
            List<LLMReq.ElementValue> linkingValues) {
        // groupingBy rejects null keys and toMap rejects null values; skip incomplete entries.
        Map<String, Map<String, String>> valuesByFieldName = linkingValues.stream()
                .filter(SchemaCorrector::hasFieldNameAndValue)
                .collect(Collectors.groupingBy(LLMReq.ElementValue::getFieldName,
                        Collectors.mapping(LLMReq.ElementValue::getFieldValue,
                                Collectors.toMap(value -> value, value -> value,
                                        (existing, replacement) -> replacement))));
        if (valuesByFieldName.isEmpty()) {
            return;
        }
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        sqlInfo.setCorrectedS2SQL(SqlReplaceHelper.replaceValue(
                sqlInfo.getCorrectedS2SQL(), valuesByFieldName, false));
    }

    /**
     * Removes equality filters on dimensions whose field was not among the linking values.
     *
     * <p>A WHERE condition is removed when all of the following hold:
     * <ul>
     *   <li>no function is applied to the field;</li>
     *   <li>the field is not a time dimension ({@code TimeDimensionEnum.containsTimeDimension});</li>
     *   <li>the operator equals {@code FilterOperatorEnum.EQUALS.getValue()};</li>
     *   <li>the field is one of the data set's dimensions as returned by {@code getDimensions};</li>
     *   <li>the value is not a date string per {@code DateUtils.isAnyDateString} (null values are
     *       treated as non-dates);</li>
     *   <li>no linking value carries that field name.</li>
     * </ul>
     *
     * @param chatQueryContext supplies the semantic schema used to look up dimensions
     * @param semanticParseInfo parse info whose corrected S2SQL is rewritten in place
     */
    public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
            SemanticParseInfo semanticParseInfo) {
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String correctedS2SQL = sqlInfo.getCorrectedS2SQL();
        List<FieldExpression> whereExpressions = SqlSelectHelper.getWhereExpressions(correctedS2SQL);
        if (CollectionUtils.isEmpty(whereExpressions)) {
            return;
        }
        Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(),
                chatQueryContext.getSemanticSchema());
        Set<String> linkedFieldNames = getLinkingValues(semanticParseInfo).stream()
                .map(LLMReq.ElementValue::getFieldName)
                .collect(Collectors.toSet());
        Set<String> removeFieldNames = whereExpressions.stream()
                .filter(expression -> StringUtils.isBlank(expression.getFunction()))
                .filter(expression -> !TimeDimensionEnum.containsTimeDimension(expression.getFieldName()))
                .filter(expression -> FilterOperatorEnum.EQUALS.getValue().equals(expression.getOperator()))
                .filter(expression -> dimensions.contains(expression.getFieldName()))
                .filter(expression -> !isDateValue(expression.getFieldValue()))
                .filter(expression -> !linkedFieldNames.contains(expression.getFieldName()))
                .map(FieldExpression::getFieldName)
                .collect(Collectors.toSet());
        sqlInfo.setCorrectedS2SQL(
                SqlRemoveHelper.removeWhereCondition(correctedS2SQL, removeFieldNames));
    }

    /** Returns true when the linking value has both a non-null field name and field value. */
    private static boolean hasFieldNameAndValue(LLMReq.ElementValue elementValue) {
        return Objects.nonNull(elementValue)
                && Objects.nonNull(elementValue.getFieldName())
                && Objects.nonNull(elementValue.getFieldValue());
    }

    /** Returns true when {@code value} is non-null and its string form is a date string. */
    private static boolean isDateValue(Object value) {
        return Objects.nonNull(value) && DateUtils.isAnyDateString(value.toString());
    }
}
