package com.tencent.supersonic.common.service.impl;

import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections.MapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
/**
 * {@link EmbeddingService} backed by a langchain4j {@link EmbeddingModel} and a per-collection
 * {@link EmbeddingStore} obtained from {@link EmbeddingStoreFactory}.
 */
public class EmbeddingServiceImpl implements EmbeddingService {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingServiceImpl.class);

    @Autowired
    private EmbeddingStoreFactory embeddingStoreFactory;

    @Autowired
    private EmbeddingModel embeddingModel;

    /**
     * Embeds each segment with the configured model and adds it to the store for
     * {@code collectionName}.
     *
     * <p>Errors are handled per segment: a failure while embedding or storing one segment is
     * logged and the remaining segments are still processed.
     *
     * @param collectionName name passed to {@link EmbeddingStoreFactory#create(String)}
     * @param queries segments to embed and store; {@code null} entries are skipped and a
     *     {@code null} list adds nothing
     */
    @Override
    public void addQuery(String collectionName, List<TextSegment> queries) {
        // The store is obtained before the null check, as the original did, so any side effect
        // of creating it is preserved.
        EmbeddingStore<TextSegment> embeddingStore = embeddingStoreFactory.create(collectionName);
        if (queries == null) {
            return;
        }
        for (TextSegment segment : queries) {
            if (Objects.isNull(segment)) {
                continue;
            }
            String text = segment.text();
            try {
                Embedding embedding = embeddingModel.embed(text).content();
                embeddingStore.add(embedding, segment);
            } catch (Exception e) {
                log.error("embeddingModel embed error question: {}, embeddingStore: {}", text,
                        embeddingStore.getClass().getSimpleName(), e);
            }
        }
    }

    /**
     * Not implemented: this method currently does nothing and removes no entries from the store.
     *
     * @param collectionName ignored
     * @param queries ignored
     */
    @Override
    public void deleteQuery(String collectionName, List<TextSegment> queries) {
        // Intentionally a no-op.
    }

    /**
     * Runs a similarity search for every query text in {@code retrieveQuery}.
     *
     * <p>Each query text is embedded and searched in the store for {@code collectionName},
     * restricted by the request's filter condition (every entry must match exactly). The
     * matches of each query are returned closest first, i.e. ordered by ascending
     * {@code distance = 1 - score}, and capped at {@code num}. Matches that carry no text
     * segment are skipped.
     *
     * @param collectionName name passed to {@link EmbeddingStoreFactory#create(String)}
     * @param retrieveQuery query texts plus an optional metadata equality filter
     * @param num maximum number of matches per query text
     * @return one result per query text, in input order; empty when there are no query texts
     */
    @Override
    public List<RetrieveQueryResult> retrieveQuery(String collectionName,
            RetrieveQuery retrieveQuery, int num) {
        List<RetrieveQueryResult> results = new ArrayList<>();
        EmbeddingStore<TextSegment> embeddingStore = embeddingStoreFactory.create(collectionName);
        List<String> queryTexts = retrieveQuery.getQueryTextsList();
        if (queryTexts == null || queryTexts.isEmpty()) {
            return results;
        }
        // The filter depends only on the request, so build it once for all query texts.
        Filter filter = createCombinedFilter(retrieveQuery.getFilterCondition());

        for (String queryText : queryTexts) {
            Embedding queryEmbedding = embeddingModel.embed(queryText).content();
            EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
                    .queryEmbedding(queryEmbedding)
                    .filter(filter)
                    .maxResults(num)
                    .build();
            List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();

            List<Retrieval> retrievals = new ArrayList<>();
            for (EmbeddingMatch<TextSegment> match : matches) {
                TextSegment segment = match.embedded();
                // Without a segment there is no id or text to report; skip instead of NPE.
                if (Objects.isNull(segment)) {
                    continue;
                }
                retrievals.add(toRetrieval(match, segment));
            }

            // Closest (smallest distance) first. The limit is applied after sorting, so that a
            // store returning more than `num` matches still yields the best ones.
            List<Retrieval> ranked = retrievals.stream()
                    .sorted(Comparator.comparingDouble(Retrieval::getDistance))
                    .limit(num)
                    .collect(Collectors.toList());

            RetrieveQueryResult result = new RetrieveQueryResult();
            result.setQuery(queryText);
            result.setRetrieval(ranked);
            results.add(result);
        }
        return results;
    }

    /**
     * Converts a single store match whose embedded segment is known to be non-null.
     *
     * @param match the store match supplying the similarity score
     * @param segment the match's embedded segment, non-null
     * @return a retrieval carrying distance ({@code 1 - score}), id, text and a copy of the
     *     segment metadata
     */
    private static Retrieval toRetrieval(EmbeddingMatch<TextSegment> match, TextSegment segment) {
        Retrieval retrieval = new Retrieval();
        retrieval.setDistance(1.0d - match.score());
        retrieval.setId(TextSegmentConvert.getQueryId(segment));
        retrieval.setQuery(segment.text());
        Map<String, Object> metadata = new HashMap<>();
        Map<String, ?> segmentMetadata = segment.metadata().toMap();
        if (MapUtils.isNotEmpty(segmentMetadata)) {
            metadata.putAll(segmentMetadata);
        }
        retrieval.setMetadata(metadata);
        return retrieval;
    }

    /**
     * ANDs together one {@link IsEqualTo} filter per map entry.
     *
     * @param map metadata key to required value; may be {@code null}
     * @return the combined filter, or {@code null} when {@code map} is null or empty
     *     (no filtering)
     */
    private static Filter createCombinedFilter(Map<String, String> map) {
        if (MapUtils.isEmpty(map)) {
            return null;
        }
        Filter filter = null;
        for (Map.Entry<String, String> entry : map.entrySet()) {
            Filter condition = new IsEqualTo(entry.getKey(), entry.getValue());
            filter = filter == null ? condition : Filter.and(filter, condition);
        }
        return filter;
    }
}
