package com.tencent.supersonic.headless.chat.knowledge;

import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class MetaEmbeddingService {
    private static final Logger log = LoggerFactory.getLogger(MetaEmbeddingService.class);

    /** Metadata / filter key identifying the model a retrieved meta item belongs to. */
    private static final String MODEL_ID_KEY = "modelId";

    /** Metadata key under which the expanded data-set id is recorded on each copy. */
    private static final String DATA_SET_ID_KEY = "dataSetId";

    /** Suffix appended to the data-set id value stored under {@link #DATA_SET_ID_KEY}. */
    private static final String DATA_SET_ID_SUFFIX = "_";

    @Autowired
    private EmbeddingService embeddingService;

    @Autowired
    private EmbeddingConfig embeddingConfig;

    /**
     * Queries the meta embedding collection and post-processes the hits by model and data set.
     *
     * <p>The model ids to search are resolved with {@code NatureHelper.getModelIds}. When exactly
     * one model id results, a {@code modelId} filter condition is set on {@code retrieveQuery}
     * (the query object is mutated). After retrieval:
     * <ul>
     *   <li>if the service returns nothing, an empty mutable list is returned;</li>
     *   <li>if no model ids were resolved, the service results are returned unfiltered;</li>
     *   <li>otherwise, retrievals whose {@code modelId} metadata names a model outside the
     *       resolved set are dropped (retrievals without a model id are kept). Each kept
     *       retrieval is then expanded into one copy per data-set id mapped from its model,
     *       tagged with {@code dataSetId = id + "_"} unless already present. Results left
     *       with no retrievals are dropped.</li>
     * </ul>
     * The returned {@link RetrieveQueryResult} instances are the service's objects, updated in
     * place via {@code setRetrieval}.
     *
     * @param retrieveQuery the query to run; may receive a {@code modelId} filter condition
     * @param num passed through to the embedding service as the result count
     * @param modelIdToDataSetIds model id to the data-set ids that model should be expanded into
     * @param detectDataSetIds data-set ids passed to {@code NatureHelper.getModelIds}
     * @return the filtered and expanded results, never {@code null}
     */
    public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
            Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
        Set<Long> modelIds = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
        if (modelIds != null && modelIds.size() == 1) {
            String onlyModelId = modelIds.iterator().next().toString();
            // Diamond lets the compiler pick the filter map's value type from the setter.
            retrieveQuery.setFilterCondition(
                    new HashMap<>(Collections.singletonMap(MODEL_ID_KEY, onlyModelId)));
        }

        List<RetrieveQueryResult> results = embeddingService
                .retrieveQuery(embeddingConfig.getMetaCollectionName(), retrieveQuery, num);
        if (CollectionUtils.isEmpty(results)) {
            return new ArrayList<>();
        }
        if (CollectionUtils.isEmpty(modelIds)) {
            return results;
        }
        return results.stream()
                .map(result -> restrictToModels(result, modelIds, modelIdToDataSetIds))
                .filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
                .collect(Collectors.toList());
    }

    /** Replaces a result's retrievals with those in {@code modelIds}, expanded by data set. */
    private static RetrieveQueryResult restrictToModels(RetrieveQueryResult result,
            Set<Long> modelIds, Map<Long, List<Long>> modelIdToDataSetIds) {
        List<Retrieval> retrievals = result.getRetrieval();
        if (CollectionUtils.isEmpty(retrievals)) {
            return result;
        }
        // Build a new list instead of removeIf so an unmodifiable service list is not mutated.
        List<Retrieval> kept = retrievals.stream()
                .filter(retrieval -> belongsToModels(retrieval, modelIds))
                .flatMap(retrieval -> expandByDataSet(retrieval, modelIdToDataSetIds))
                .collect(Collectors.toList());
        result.setRetrieval(kept);
        return result;
    }

    /** True when the retrieval has no model id or its model id is in {@code modelIds}. */
    private static boolean belongsToModels(Retrieval retrieval, Set<Long> modelIds) {
        Long modelId = Retrieval.getLongId(retrieval.getMetadata().get(MODEL_ID_KEY));
        return modelId == null || modelIds.contains(modelId);
    }

    /** Emits one tagged copy per mapped data-set id, or the retrieval itself if none map. */
    private static Stream<Retrieval> expandByDataSet(Retrieval retrieval,
            Map<Long, List<Long>> modelIdToDataSetIds) {
        Long modelId = Retrieval.getLongId(retrieval.getMetadata().get(MODEL_ID_KEY));
        List<Long> dataSetIds = (modelId == null || modelIdToDataSetIds == null)
                ? null
                : modelIdToDataSetIds.get(modelId);
        if (CollectionUtils.isEmpty(dataSetIds)) {
            return Stream.of(retrieval);
        }
        // distinct() mirrors the original HashSet de-duplication (same equals/hashCode).
        return dataSetIds.stream()
                .map(dataSetId -> copyForDataSet(retrieval, dataSetId))
                .distinct();
    }

    /** Returns a copy of {@code source} whose own metadata records {@code dataSetId}. */
    private static Retrieval copyForDataSet(Retrieval source, Long dataSetId) {
        Retrieval copy = new Retrieval();
        BeanUtils.copyProperties(source, copy);
        // copyProperties is shallow: without a fresh map every copy (and the source) would
        // share one metadata map, so only the first dataSetId would ever be recorded.
        copy.setMetadata(new HashMap<>(source.getMetadata()));
        copy.getMetadata().putIfAbsent(DATA_SET_ID_KEY, dataSetId + DATA_SET_ID_SUFFIX);
        return copy;
    }
}
