/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.headless.chat.parser;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Semantic parser stage that initializes the S2SQL of every candidate query in a
 * {@link ChatQueryContext} and then tags each candidate with a {@link QueryType}
 * ({@code ID}, {@code DETAIL} or {@code METRIC}) derived from the fields its parsed S2SQL
 * references.
 */
public class QueryTypeParser implements SemanticParser {
    private static final Logger log = LoggerFactory.getLogger(QueryTypeParser.class);

    /**
     * For each candidate query: calls {@code initS2Sql} with the schema of the candidate's data set
     * and the requesting user, then stores the inferred {@link QueryType} on its parse info.
     *
     * @param chatQueryContext context supplying the candidate queries, the user and the semantic
     *     schema
     */
    @Override
    public void parse(ChatQueryContext chatQueryContext) {
        List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
        User user = chatQueryContext.getUser();
        for (SemanticQuery semanticQuery : candidateQueries) {
            // 1. Initialize S2SQL against the schema of the data set this candidate targets.
            // NOTE(review): the map lookup may yield null for an unknown data set id; initS2Sql
            // receives it as-is, as before.
            Long dataSetId = semanticQuery.getParseInfo().getDataSetId();
            DataSetSchema dataSetSchema =
                    chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
            semanticQuery.initS2Sql(dataSetSchema, user);
            // 2. Classify the S2SQL produced by step 1.
            QueryType queryType = getQueryType(chatQueryContext, semanticQuery);
            semanticQuery.getParseInfo().setQueryType(queryType);
        }
    }

    /**
     * Infers the query type of a candidate whose S2SQL has already been initialized.
     *
     * <p>Rules, evaluated in order:
     * <ol>
     *   <li>no parsed S2SQL: {@link QueryType#DETAIL};</li>
     *   <li>only for {@link RuleSemanticQuery} / {@link LLMSqlQuery} candidates:
     *     <ul>
     *       <li>some non-time WHERE field is an entity name of the data set:
     *           {@link QueryType#ID};</li>
     *       <li>every non-time SELECT/WHERE field is a tag name of the data set:
     *           {@link QueryType#DETAIL};</li>
     *     </ul>
     *   </li>
     *   <li>the SELECT list references a metric name of the data set: {@link QueryType#METRIC};</li>
     *   <li>otherwise {@link QueryType#DETAIL}.</li>
     * </ol>
     *
     * @param chatQueryContext context supplying the semantic schema
     * @param semanticQuery candidate to classify
     * @return the inferred query type, never {@code null}
     */
    private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) {
        SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
        SqlInfo sqlInfo = parseInfo.getSqlInfo();
        if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
            return QueryType.DETAIL;
        }
        String parsedS2Sql = sqlInfo.getParsedS2SQL();
        Long dataSetId = parseInfo.getDataSetId();
        SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();

        if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
            List<String> whereFields = SqlSelectHelper.getWhereFields(parsedS2Sql);

            // ID: a non-time WHERE field matches an entity name of the data set.
            List<String> nonTimeWhereFields = excludeTimeFields(whereFields);
            if (CollectionUtils.isNotEmpty(nonTimeWhereFields)) {
                Set<String> entityNames = toNames(semanticSchema.getEntities(dataSetId));
                // Iterate the list and probe the set: same intersection test, O(n + m).
                if (nonTimeWhereFields.stream().anyMatch(entityNames::contains)) {
                    return QueryType.ID;
                }
            }

            // Copy before merging so the list returned by SqlSelectHelper is never mutated
            // (it may be unmodifiable or shared with the caller).
            List<String> selectAndWhereFields =
                    new ArrayList<>(SqlSelectHelper.getSelectFields(parsedS2Sql));
            selectAndWhereFields.addAll(whereFields);

            // DETAIL: every non-time SELECT/WHERE field is a tag of the data set.
            List<String> nonTimeSelectAndWhereFields = excludeTimeFields(selectAndWhereFields);
            if (CollectionUtils.isNotEmpty(nonTimeSelectAndWhereFields)) {
                Set<String> tagNames = toNames(semanticSchema.getTags(dataSetId));
                if (CollectionUtils.isNotEmpty(tagNames)
                        && tagNames.containsAll(nonTimeSelectAndWhereFields)) {
                    return QueryType.DETAIL;
                }
            }
        }

        if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) {
            return QueryType.METRIC;
        }
        return QueryType.DETAIL;
    }

    /**
     * Returns the fields for which {@code TimeDimensionEnum.containsTimeDimension} is false,
     * preserving order and duplicates.
     *
     * @param fields field names to filter
     * @return a new list with the time-dimension fields removed
     */
    private static List<String> excludeTimeFields(List<String> fields) {
        return fields.stream()
                .filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
                .collect(Collectors.toList());
    }

    /**
     * Collects the names of the given schema elements into a set for membership checks.
     *
     * @param elements schema elements; {@code null} is treated as empty
     * @return the element names (an empty set when {@code elements} is null or empty)
     */
    private static Set<String> toNames(Collection<? extends SchemaElement> elements) {
        if (elements == null) {
            return Collections.emptySet();
        }
        return elements.stream().map(SchemaElement::getName).collect(Collectors.toSet());
    }

    /**
     * Tells whether the SELECT list of the parsed S2SQL references any metric name of the data set.
     *
     * @param sqlInfo SQL info holding a non-blank parsed S2SQL
     * @param dataSetId data set whose metrics are consulted
     * @param semanticSchema schema providing the data set's metrics
     * @return {@code true} if at least one selected field is a metric name
     */
    private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,
            SemanticSchema semanticSchema) {
        List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
        if (CollectionUtils.isEmpty(metrics)) {
            return false;
        }
        Set<String> metricNames = toNames(metrics);
        List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
        return selectFields.stream().anyMatch(metricNames::contains);
    }
}

