package com.tencent.supersonic.headless.server.facade.service.impl;

import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.web.service.DataSetService;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * {@link ChatLayerService} implementation that drives the headless chat pipeline: schema mapping
 * of natural-language text, parsing via {@link ChatWorkflowEngine}, and correction/validation of
 * S2SQL through the registered semantic correctors.
 */
@Service
public class S2ChatLayerService implements ChatLayerService {
    private static final Logger log = LoggerFactory.getLogger(S2ChatLayerService.class);

    @Autowired
    private SchemaService schemaService;

    @Autowired
    private DataSetService dataSetService;

    @Autowired
    private ChatWorkflowEngine chatWorkflowEngine;

    /**
     * Runs every registered schema mapper against the request and returns the collected
     * mapping information.
     *
     * @param queryNLReq the natural-language query request
     * @return the mapping result together with the original query text
     */
    @Override
    public MapResp performMapping(QueryNLReq queryNLReq) {
        MapResp mapResp = new MapResp();
        ChatQueryContext chatQueryContext = buildChatQueryContext(queryNLReq);
        ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(chatQueryContext));
        mapResp.setMapInfo(chatQueryContext.getMapInfo());
        mapResp.setQueryText(queryNLReq.getQueryText());
        return mapResp;
    }

    /**
     * Maps the request's text against the named data sets and converts the result into a
     * per-data-set view (mapped fields, top fields, terms).
     *
     * @param queryMapReq the mapping request (data set names, user, topN, ...)
     * @return the converted mapping response
     */
    @Override
    public MapInfoResp map(QueryMapReq queryMapReq) {
        QueryNLReq queryNLReq = new QueryNLReq();
        BeanUtils.copyProperties(queryMapReq, queryNLReq);
        // Collect into an explicitly mutable set: Collectors.toSet() gives no mutability guarantee.
        Set<Long> dataSetIds = dataSetService
                .getDataSets(queryMapReq.getDataSetNames(), queryMapReq.getUser())
                .stream()
                .map(dataSet -> dataSet.getId())
                .collect(Collectors.toCollection(HashSet::new));
        queryNLReq.setDataSetIds(dataSetIds);
        MapResp mapResp = performMapping(queryNLReq);
        // Intersect a copy so the set already handed to the request is not mutated afterwards.
        Set<Long> matchedDataSetIds = new HashSet<>(dataSetIds);
        matchedDataSetIds.retainAll(mapResp.getMapInfo().getDataSetElementMatches().keySet());
        return convert(mapResp, queryMapReq.getTopN(), matchedDataSetIds);
    }

    /**
     * Executes the chat workflow for the request and exposes the parse info of every candidate
     * query as the selected parses.
     *
     * @param queryNLReq the natural-language query request
     * @return the parse response populated by the workflow engine
     */
    @Override
    public ParseResp performParsing(QueryNLReq queryNLReq) {
        ParseResp parseResp = new ParseResp(queryNLReq.getQueryText());
        ChatQueryContext chatQueryContext = buildChatQueryContext(queryNLReq);
        chatWorkflowEngine.execute(chatQueryContext, parseResp);
        List<SemanticParseInfo> selectedParses = chatQueryContext.getCandidateQueries().stream()
                .map(query -> query.getParseInfo())
                .collect(Collectors.toList());
        parseResp.setSelectedParses(selectedParses);
        return parseResp;
    }

    /**
     * Builds a fresh {@link ChatQueryContext} for the request, seeded with the current semantic
     * schema, an empty candidate list, an empty {@link SchemaMapInfo} and the model-to-data-set
     * index.
     *
     * @param queryNLReq the natural-language query request
     * @return the initialised context
     */
    public ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) {
        SemanticSchema semanticSchema = schemaService.getSemanticSchema();
        ChatQueryContext chatQueryContext = ChatQueryContext.builder()
                .queryFilters(queryNLReq.getQueryFilters())
                .semanticSchema(semanticSchema)
                .candidateQueries(new ArrayList<>())
                .mapInfo(new SchemaMapInfo())
                .modelIdToDataSetIds(dataSetService.getModelIdToDataSetIds())
                .text2SQLType(queryNLReq.getText2SQLType())
                .mapModeEnum(queryNLReq.getMapModeEnum())
                .dataSetIds(queryNLReq.getDataSetIds())
                .build();
        // Spring BeanUtils then copies every same-named property from the request onto the
        // context, including ones already set through the builder.
        BeanUtils.copyProperties(queryNLReq, chatQueryContext);
        return chatQueryContext;
    }

    /**
     * Replaces the request's SQL with its corrected S2SQL.
     *
     * @param querySqlReq the request whose SQL is rewritten in place
     * @param user        the user, used to resolve the data set when the request has none
     */
    @Override
    public void correct(QuerySqlReq querySqlReq, User user) {
        querySqlReq.setSql(correctSqlReq(querySqlReq, user).getSqlInfo().getCorrectedS2SQL());
    }

    /**
     * Runs the correctors over the request's SQL and returns the resulting evaluation.
     *
     * @param querySqlReq the SQL request
     * @param user        the user, used to resolve the data set when the request has none
     * @return the SQL evaluation recorded on the corrected parse info
     */
    @Override
    public SqlEvaluation validate(QuerySqlReq querySqlReq, User user) {
        return correctSqlReq(querySqlReq, user).getSqlEvaluation();
    }

    /**
     * Builds a DETAIL-type parse info for the request's SQL and runs every registered semantic
     * corrector over it, except {@link GrammarCorrector} and {@link SchemaCorrector}.
     * The data set comes from the request or, if absent, is inferred from the SQL.
     */
    private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) {
        ChatQueryContext chatQueryContext = new ChatQueryContext();
        SemanticSchema semanticSchema = schemaService.getSemanticSchema();
        chatQueryContext.setSemanticSchema(semanticSchema);

        SqlInfo sqlInfo = new SqlInfo();
        sqlInfo.setCorrectedS2SQL(querySqlReq.getSql());
        sqlInfo.setParsedS2SQL(querySqlReq.getSql());

        SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
        semanticParseInfo.setSqlInfo(sqlInfo);
        semanticParseInfo.setQueryType(QueryType.DETAIL);

        Long dataSetId = querySqlReq.getDataSetId();
        if (Objects.isNull(dataSetId)) {
            dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
        }
        semanticParseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));

        ComponentFactory.getSemanticCorrectors().forEach(corrector -> {
            if (corrector instanceof GrammarCorrector || corrector instanceof SchemaCorrector) {
                return;
            }
            corrector.correct(chatQueryContext, semanticParseInfo);
        });
        log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL());
        return semanticParseInfo;
    }

    /**
     * Converts a raw mapping result into a {@link MapInfoResp}, restricted to the given data set
     * ids.
     *
     * @param mapResp the mapping result; a {@code null} value yields an empty response
     * @param num     how many top fields to include per data set (null or non-positive means none)
     * @param set     the data set ids to load metadata for
     */
    private MapInfoResp convert(MapResp mapResp, Integer num, Set<Long> set) {
        MapInfoResp mapInfoResp = new MapInfoResp();
        if (Objects.isNull(mapResp)) {
            return mapInfoResp;
        }
        BeanUtils.copyProperties(mapResp, mapInfoResp);
        MetaFilter metaFilter = new MetaFilter();
        metaFilter.setIds(new ArrayList<>(set));
        Map<Long, DataSetResp> dataSetMap = dataSetService.getDataSetList(metaFilter).stream()
                .collect(Collectors.toMap(dataSet -> dataSet.getId(), Function.identity()));
        mapInfoResp.setDataSetMapInfo(getDataSetInfo(mapResp.getMapInfo(), dataSetMap, num));
        mapInfoResp.setTerms(getTerms(mapResp.getMapInfo(), dataSetMap));
        return mapInfoResp;
    }

    /**
     * Builds a {@link DataSetMapInfo} for every known data set that has at least one non-term
     * match, keyed by data set name.
     */
    private Map<String, DataSetMapInfo> getDataSetInfo(SchemaMapInfo schemaMapInfo,
            Map<Long, DataSetResp> map, Integer num) {
        Map<String, DataSetMapInfo> dataSetMapInfos = new HashMap<>();
        Map<Long, List<SchemaElementMatch>> mapFields = getMapFields(schemaMapInfo, map);
        Map<Long, List<SchemaElementMatch>> topFields = getTopFields(num, schemaMapInfo, map);
        for (Long dataSetId : schemaMapInfo.getDataSetElementMatches().keySet()) {
            DataSetResp dataSetResp = map.get(dataSetId);
            if (dataSetResp == null || CollectionUtils.isEmpty(mapFields.get(dataSetId))) {
                continue;
            }
            DataSetMapInfo dataSetMapInfo = new DataSetMapInfo();
            dataSetMapInfo.setMapFields(mapFields.getOrDefault(dataSetId, Lists.newArrayList()));
            dataSetMapInfo.setTopFields(topFields.getOrDefault(dataSetId, Lists.newArrayList()));
            dataSetMapInfo.setName(dataSetResp.getName());
            dataSetMapInfo.setDescription(dataSetResp.getDescription());
            // Keyed by name: data sets that share a name overwrite one another.
            dataSetMapInfos.put(dataSetMapInfo.getName(), dataSetMapInfo);
        }
        return dataSetMapInfos;
    }

    /**
     * Returns, per known data set, its matches excluding TERM elements. Data sets left with no
     * such matches are omitted.
     */
    private Map<Long, List<SchemaElementMatch>> getMapFields(SchemaMapInfo schemaMapInfo,
            Map<Long, DataSetResp> map) {
        Map<Long, List<SchemaElementMatch>> mapFields = new HashMap<>();
        for (Map.Entry<Long, List<SchemaElementMatch>> entry
                : schemaMapInfo.getDataSetElementMatches().entrySet()) {
            if (!map.containsKey(entry.getKey())) {
                continue;
            }
            List<SchemaElementMatch> nonTermMatches = entry.getValue().stream()
                    .filter(match -> !SchemaElementType.TERM.equals(match.getElement().getType()))
                    .collect(Collectors.toList());
            if (CollectionUtils.isNotEmpty(nonTermMatches)) {
                mapFields.put(entry.getKey(), nonTermMatches);
            }
        }
        return mapFields;
    }

    /**
     * Returns, per known data set with at least one match, its most used fields. These are the
     * top {@code num - 1} dimensions by useCnt, the day time dimension, and the top {@code num}
     * metrics by useCnt.
     *
     * @param num the requested topN; {@code null} or a non-positive value yields an empty map
     */
    private Map<Long, List<SchemaElementMatch>> getTopFields(Integer num,
            SchemaMapInfo schemaMapInfo, Map<Long, DataSetResp> map) {
        Map<Long, List<SchemaElementMatch>> topFields = new HashMap<>();
        // Guard against a null topN (NPE on unboxing) and negative values (Stream.limit would
        // throw IllegalArgumentException).
        int topN = num == null ? 0 : num;
        if (topN <= 0) {
            return topFields;
        }
        SemanticSchema semanticSchema = schemaService.getSemanticSchema();
        for (Map.Entry<Long, List<SchemaElementMatch>> entry
                : schemaMapInfo.getDataSetElementMatches().entrySet()) {
            Long dataSetId = entry.getKey();
            DataSetResp dataSetResp = map.get(dataSetId);
            if (dataSetResp == null || CollectionUtils.isEmpty(entry.getValue())) {
                continue;
            }
            // LinkedHashSet de-duplicates like the former HashSet, but keeps the useCnt ranking
            // (dimensions, then the time dimension, then metrics) deterministic.
            Set<SchemaElementMatch> fields = semanticSchema.getDimensions(dataSetId).stream()
                    .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
                    .limit(topN - 1)
                    .map(mergeFunction())
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            fields.add(getTimeDimension(dataSetId, dataSetResp.getName()));
            fields.addAll(semanticSchema.getMetrics(dataSetId).stream()
                    .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
                    .limit(topN)
                    .map(mergeFunction())
                    .collect(Collectors.toList()));
            topFields.put(dataSetId, new ArrayList<>(fields));
        }
        return topFields;
    }

    /** Returns, per known data set name, only its TERM matches (possibly an empty list). */
    private Map<String, List<SchemaElementMatch>> getTerms(SchemaMapInfo schemaMapInfo,
            Map<Long, DataSetResp> map) {
        Map<String, List<SchemaElementMatch>> terms = new HashMap<>();
        for (Map.Entry<Long, List<SchemaElementMatch>> entry
                : schemaMapInfo.getDataSetElementMatches().entrySet()) {
            DataSetResp dataSetResp = map.get(entry.getKey());
            if (dataSetResp == null) {
                continue;
            }
            List<SchemaElementMatch> termMatches = entry.getValue().stream()
                    .filter(match -> SchemaElementType.TERM.equals(match.getElement().getType()))
                    .collect(Collectors.toList());
            terms.put(dataSetResp.getName(), termMatches);
        }
        return terms;
    }

    /**
     * Builds a synthetic DIMENSION match for the data set's {@link TimeDimensionEnum#DAY} field,
     * with similarity 1.0 and the default word frequency.
     */
    private SchemaElementMatch getTimeDimension(Long l, String str) {
        SchemaElement timeDimension = SchemaElement.builder()
                .dataSet(l)
                .dataSetName(str)
                .type(SchemaElementType.DIMENSION)
                .bizName(TimeDimensionEnum.DAY.getName())
                .build();
        return SchemaElementMatch.builder()
                .element(timeDimension)
                .detectWord(TimeDimensionEnum.DAY.getChName())
                .word(TimeDimensionEnum.DAY.getChName())
                .similarity(1.0d)
                .frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
                .build();
    }

    /**
     * Wraps a schema element in a match whose word and detect word are the element's name, with
     * similarity 1.0 and the default word frequency.
     */
    private Function<SchemaElement, SchemaElementMatch> mergeFunction() {
        return schemaElement -> SchemaElementMatch.builder()
                .element(schemaElement)
                .frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
                .word(schemaElement.getName())
                .similarity(1.0d)
                .detectWord(schemaElement.getName())
                .build();
    }
}
