/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.headless.chat.parser.llm;

import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult;
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HeuristicDataSetResolver
implements DataSetResolver {
    private static final Logger log = LoggerFactory.getLogger(HeuristicDataSetResolver.class);

    protected static Long selectDataSetBySchemaElementMatchScore(Map<Long, SemanticQuery> dataSetQueryModes, SchemaMapInfo schemaMap) {
        Long dataSetIdByDataSetCount = HeuristicDataSetResolver.getDataSetIdByMatchDataSetScore(schemaMap);
        if (Objects.nonNull(dataSetIdByDataSetCount)) {
            log.info("selectDataSet by dataSet count:{}", (Object)dataSetIdByDataSetCount);
            return dataSetIdByDataSetCount;
        }
        Map<Long, DataSetMatchResult> dataSetTypeMap = HeuristicDataSetResolver.getDataSetTypeMap(schemaMap);
        if (dataSetTypeMap.size() == 1) {
            Long dataSetSelect = new ArrayList<Map.Entry<Long, DataSetMatchResult>>(dataSetTypeMap.entrySet()).get(0).getKey();
            if (dataSetQueryModes.containsKey(dataSetSelect)) {
                log.info("selectDataSet with only one DataSet [{}]", (Object)dataSetSelect);
                return dataSetSelect;
            }
        } else {
            Map.Entry maxDataSet = dataSetTypeMap.entrySet().stream().filter(entry -> dataSetQueryModes.containsKey(entry.getKey())).sorted((o1, o2) -> {
                int difference = ((DataSetMatchResult)o2.getValue()).getCount() - ((DataSetMatchResult)o1.getValue()).getCount();
                if (difference == 0) {
                    return (int)((((DataSetMatchResult)o2.getValue()).getMaxSimilarity() - ((DataSetMatchResult)o1.getValue()).getMaxSimilarity()) * 100.0);
                }
                return difference;
            }).findFirst().orElse(null);
            if (maxDataSet != null) {
                log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
                return (Long)maxDataSet.getKey();
            }
        }
        return null;
    }

    private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
        Map dataSetElementMatches = schemaMap.getDataSetElementMatches();
        HashMap<Long, Double> dataSetIdToDataSetScore = new HashMap<Long, Double>();
        if (Objects.nonNull(dataSetElementMatches)) {
            for (Map.Entry dataSetElementMatch : dataSetElementMatches.entrySet()) {
                Long dataSetId = (Long)dataSetElementMatch.getKey();
                List dataSetMatchesScore = ((List)dataSetElementMatch.getValue()).stream().filter(elementMatch -> elementMatch.getSimilarity() >= 1.0).filter(elementMatch -> SchemaElementType.DATASET.equals((Object)elementMatch.getElement().getType())).map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
                if (CollectionUtils.isEmpty(dataSetMatchesScore)) continue;
                double score = dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
                dataSetIdToDataSetScore.put(dataSetId, score);
            }
            Map.Entry maxDataSetScore = dataSetIdToDataSetScore.entrySet().stream().max(Comparator.comparingDouble(Map.Entry::getValue)).orElse(null);
            log.info("maxDataSetCount:{},dataSetIdToDataSetCount:{}", (Object)maxDataSetScore, dataSetIdToDataSetScore);
            if (Objects.nonNull(maxDataSetScore)) {
                return (Long)maxDataSetScore.getKey();
            }
        }
        return null;
    }

    public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
        HashMap<Long, DataSetMatchResult> dataSetCount = new HashMap<Long, DataSetMatchResult>();
        for (Map.Entry entry : schemaMap.getDataSetElementMatches().entrySet()) {
            List schemaElementMatches = schemaMap.getMatchedElements((Long)entry.getKey());
            if (schemaElementMatches == null || schemaElementMatches.size() <= 0) continue;
            if (!dataSetCount.containsKey(entry.getKey())) {
                dataSetCount.put((Long)entry.getKey(), new DataSetMatchResult());
            }
            DataSetMatchResult dataSetMatchResult = (DataSetMatchResult)dataSetCount.get(entry.getKey());
            HashSet schemaElementTypes = new HashSet();
            schemaElementMatches.stream().forEach(schemaElementMatch -> schemaElementTypes.add(schemaElementMatch.getElement().getType()));
            SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream().sorted((o1, o2) -> (int)((o2.getSimilarity() - o1.getSimilarity()) * 100.0)).findFirst().orElse(null);
            if (schemaElementMatchMax != null) {
                dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
            }
            dataSetMatchResult.setCount(schemaElementTypes.size());
        }
        return dataSetCount;
    }

    @Override
    public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
        SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
        Set matchedDataSets = mapInfo.getMatchedDataSetInfos();
        if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
            matchedDataSets.retainAll(agentDataSetIds);
        }
        HashMap<Long, SemanticQuery> dataSetQueryModes = new HashMap<Long, SemanticQuery>();
        for (Long dataSetIds : matchedDataSets) {
            dataSetQueryModes.put(dataSetIds, null);
        }
        if (dataSetQueryModes.size() == 1) {
            return (Long)dataSetQueryModes.keySet().stream().findFirst().get();
        }
        return HeuristicDataSetResolver.selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
    }
}

