/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

public abstract class BaseSemanticCorrector
implements SemanticCorrector {
    private static final Logger log = LoggerFactory.getLogger(BaseSemanticCorrector.class);

    @Override
    public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
        try {
            if (StringUtils.isBlank((CharSequence)semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
                return;
            }
            this.doCorrect(chatQueryContext, semanticParseInfo);
            log.debug("sqlCorrection:{} sql:{}", (Object)this.getClass().getSimpleName(), (Object)semanticParseInfo.getSqlInfo());
        }
        catch (Exception e) {
            log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), (Throwable)e);
        }
    }

    public abstract void doCorrect(ChatQueryContext var1, SemanticParseInfo var2);

    protected Map<String, String> getFieldNameMap(ChatQueryContext chatQueryContext, Long dataSetId) {
        SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
        ArrayList dbAllFields = new ArrayList();
        dbAllFields.addAll(semanticSchema.getMetrics());
        dbAllFields.addAll(semanticSchema.getDimensions());
        Map<String, String> result = dbAllFields.stream().filter(entry -> dataSetId.equals(entry.getDataSet())).flatMap(schemaElement -> {
            HashSet<String> elements = new HashSet<String>();
            elements.add(schemaElement.getName());
            if (!CollectionUtils.isEmpty((Collection)schemaElement.getAlias())) {
                elements.addAll(schemaElement.getAlias());
            }
            return elements.stream();
        }).collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
        result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
        result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
        result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
        result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
        result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
        result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
        return result;
    }

    protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
        Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
        List<SchemaElement> metrics = this.getMetricElements(chatQueryContext, dataSetId);
        Map<String, String> metricToAggregate = metrics.stream().map(schemaElement -> {
            if (Objects.isNull(schemaElement.getDefaultAgg())) {
                schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
            }
            return schemaElement;
        }).flatMap(schemaElement -> {
            HashSet<String> elements = new HashSet<String>();
            elements.add(schemaElement.getName());
            if (!CollectionUtils.isEmpty((Collection)schemaElement.getAlias())) {
                elements.addAll(schemaElement.getAlias());
            }
            return elements.stream().map(element -> Pair.of((Object)element, (Object)schemaElement.getDefaultAgg()));
        }).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
        if (CollectionUtils.isEmpty(metricToAggregate)) {
            return;
        }
        String aggregateSql = SqlAddHelper.addAggregateToField((String)correctS2SQL, metricToAggregate);
        semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
    }

    protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) {
        SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
        return semanticSchema.getMetrics(dataSetId);
    }

    protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
        Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream().flatMap(schemaElement -> {
            HashSet<String> elements = new HashSet<String>();
            elements.add(schemaElement.getName());
            if (!CollectionUtils.isEmpty((Collection)schemaElement.getAlias())) {
                elements.addAll(schemaElement.getAlias());
            }
            return elements.stream();
        }).collect(Collectors.toSet());
        dimensions.add(TimeDimensionEnum.DAY.getChName());
        return dimensions;
    }
}

