package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.QueryContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.class */
public abstract class BaseSemanticCorrector implements SemanticCorrector {
    private static final Logger log = LoggerFactory.getLogger(BaseSemanticCorrector.class);

    @Override // com.tencent.supersonic.headless.chat.corrector.SemanticCorrector
    public void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        try {
            if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
                return;
            }
            doCorrect(queryContext, semanticParseInfo);
            log.debug("sqlCorrection:{} sql:{}", getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
        } catch (Exception e) {
            log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
        }
    }

    public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);

    /* JADX INFO: Access modifiers changed from: protected */
    public Map<String, String> getFieldNameMap(QueryContext queryContext, Long l) {
        SemanticSchema semanticSchema = queryContext.getSemanticSchema();
        ArrayList arrayList = new ArrayList();
        arrayList.addAll(semanticSchema.getMetrics());
        arrayList.addAll(semanticSchema.getDimensions());
        Map<String, String> map = (Map) arrayList.stream().filter(schemaElement -> {
            return l.equals(schemaElement.getDataSet());
        }).flatMap(schemaElement2 -> {
            HashSet hashSet = new HashSet();
            hashSet.add(schemaElement2.getName());
            if (!CollectionUtils.isEmpty(schemaElement2.getAlias())) {
                hashSet.addAll(schemaElement2.getAlias());
            }
            return hashSet.stream();
        }).collect(Collectors.toMap(str -> {
            return str;
        }, str2 -> {
            return str2;
        }, (str3, str4) -> {
            return str3;
        }));
        map.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
        map.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
        map.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
        map.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
        map.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
        map.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
        return map;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        Map map = (Map) getMetricElements(queryContext, semanticParseInfo.getDataSet().getDataSet()).stream().map(schemaElement -> {
            if (Objects.isNull(schemaElement.getDefaultAgg())) {
                schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
            }
            return schemaElement;
        }).flatMap(schemaElement2 -> {
            HashSet hashSet = new HashSet();
            hashSet.add(schemaElement2.getName());
            if (!CollectionUtils.isEmpty(schemaElement2.getAlias())) {
                hashSet.addAll(schemaElement2.getAlias());
            }
            return hashSet.stream().map(str -> {
                return Pair.of(str, schemaElement2.getDefaultAgg());
            });
        }).collect(Collectors.toMap((v0) -> {
            return v0.getLeft();
        }, (v0) -> {
            return v0.getRight();
        }, (str, str2) -> {
            return str;
        }));
        if (CollectionUtils.isEmpty(map)) {
            return;
        }
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addAggregateToField(correctS2SQL, map));
    }

    protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long l) {
        return queryContext.getSemanticSchema().getMetrics(l);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public Set<String> getDimensions(Long l, SemanticSchema semanticSchema) {
        Set<String> set = (Set) semanticSchema.getDimensions(l).stream().flatMap(schemaElement -> {
            HashSet hashSet = new HashSet();
            hashSet.add(schemaElement.getName());
            if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
                hashSet.addAll(schemaElement.getAlias());
            }
            return hashSet.stream();
        }).collect(Collectors.toSet());
        set.add(TimeDimensionEnum.DAY.getChName());
        return set;
    }
}
