package com.tencent.supersonic.common.service.impl;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactoryProvider;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

@Service
public class EmbeddingServiceImpl implements EmbeddingService {

    private static final Logger log = LoggerFactory.getLogger(EmbeddingServiceImpl.class);

    /**
     * Remembers whether a query id is already stored in a given collection, so repeated
     * {@link #addQuery} calls can skip the store lookup. Keys come from {@link #cacheKey}.
     * At most 10,000 entries are kept, each expiring 10 hours after it was written.
     */
    private final Cache<String, Boolean> cache = CacheBuilder.newBuilder()
            .maximumSize(10000)
            .expireAfterWrite(10, TimeUnit.HOURS)
            .build();

    /**
     * Embeds each segment and adds it to the named collection, unless a segment with the same
     * query id is already stored there.
     *
     * <p>Failures are handled per segment. An exception while embedding, looking up or adding
     * one segment is logged, and the remaining segments are still processed.
     *
     * @param collectionName collection name passed to the embedding-store factory
     * @param queries segments to add; segments without a query id are always added
     */
    @Override
    public void addQuery(String collectionName, List<TextSegment> queries) {
        EmbeddingStore<TextSegment> embeddingStore =
                EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
        for (TextSegment query : queries) {
            String question = query.text();
            try {
                Embedding embedding = ModelProvider.getEmbeddingModel().embed(question).content();
                String queryId = TextSegmentConvert.getQueryId(query);
                if (existSegment(collectionName, embeddingStore, queryId, embedding)) {
                    continue;
                }
                embeddingStore.add(embedding, query);
                if (queryId != null) {
                    // The lookup above may have cached "absent" for this id. Without this update,
                    // a later segment with the same id would be added a second time.
                    cache.put(cacheKey(collectionName, queryId), Boolean.TRUE);
                }
            } catch (Exception e) {
                log.error("embeddingModel embed error question: {}, embeddingStore: {}",
                        question, embeddingStore.getClass().getSimpleName(), e);
            }
        }
    }

    /**
     * Reports whether {@code embeddingStore} already holds a segment whose query-id metadata
     * equals {@code queryId}. The cache is consulted first; otherwise the store is searched
     * and its answer is cached.
     *
     * @param collectionName collection the store belongs to (part of the cache key)
     * @param embeddingStore store to search on a cache miss
     * @param queryId query id to look for; may be {@code null}
     * @param embedding probe vector used for the filtered search
     * @return {@code false} when {@code queryId} is {@code null}, otherwise whether a match exists
     */
    private boolean existSegment(String collectionName, EmbeddingStore<TextSegment> embeddingStore,
            String queryId, Embedding embedding) {
        if (queryId == null) {
            return false;
        }
        String key = cacheKey(collectionName, queryId);
        Boolean cached = cache.getIfPresent(key);
        if (cached != null) {
            return cached;
        }
        Map<String, String> condition = new HashMap<>();
        condition.put(TextSegmentConvert.QUERY_ID, queryId);
        // The metadata filter restricts hits to this query id; a single hit is enough.
        EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
                .queryEmbedding(embedding)
                .filter(createCombinedFilter(condition))
                .maxResults(1)
                .build();
        boolean exists = CollectionUtils.isNotEmpty(embeddingStore.search(request).matches());
        cache.put(key, exists);
        return exists;
    }

    /**
     * Removes the given segments, matched by query id, from the named collection.
     *
     * <p>Only an {@link InMemoryEmbeddingStore} is modified; for any other store type this
     * method does nothing. Errors are logged, not thrown.
     *
     * @param collectionName collection name passed to the embedding-store factory
     * @param queries segments whose query ids should be removed; segments without an id are ignored
     */
    @Override
    public void deleteQuery(String collectionName, List<TextSegment> queries) {
        EmbeddingStore<TextSegment> embeddingStore =
                EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
        try {
            if (embeddingStore instanceof InMemoryEmbeddingStore) {
                InMemoryEmbeddingStore<TextSegment> inMemoryStore =
                        (InMemoryEmbeddingStore<TextSegment>) embeddingStore;
                List<String> queryIds = queries.stream()
                        .map(query -> TextSegmentConvert.getQueryId(query))
                        .filter(Objects::nonNull)
                        .collect(Collectors.toList());
                if (CollectionUtils.isNotEmpty(queryIds)) {
                    inMemoryStore.removeAll(
                            new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID).isIn(queryIds));
                    // Drop cached "present" answers so these ids can be added again right away.
                    for (String queryId : queryIds) {
                        cache.invalidate(cacheKey(collectionName, queryId));
                    }
                }
            }
        } catch (Exception e) {
            log.error("deleteQuery error,collectionName:{},queries:{}", collectionName, queries, e);
        }
    }

    /**
     * Runs one similarity search per query text against the named collection.
     *
     * @param collectionName collection name passed to the embedding-store factory
     * @param retrieveQuery query texts plus optional metadata equality conditions (ANDed together)
     * @param num maximum number of retrievals per query text
     * @return one result per query text, in input order; empty when there are no query texts
     */
    @Override
    public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery,
            int num) {
        List<RetrieveQueryResult> results = new ArrayList<>();
        EmbeddingStore<TextSegment> embeddingStore =
                EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
        List<String> queryTexts = retrieveQuery.getQueryTextsList();
        if (CollectionUtils.isEmpty(queryTexts)) {
            return results;
        }
        // The filter is the same for every query text, so it is built once.
        Filter filter = createCombinedFilter(retrieveQuery.getFilterCondition());
        for (String queryText : queryTexts) {
            Embedding queryEmbedding = ModelProvider.getEmbeddingModel().embed(queryText).content();
            EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
                    .queryEmbedding(queryEmbedding)
                    .filter(filter)
                    .maxResults(num)
                    .build();
            List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();

            List<Retrieval> retrievals = new ArrayList<>(matches.size());
            for (EmbeddingMatch<TextSegment> match : matches) {
                // A match without a segment has no text or id to report; skip it.
                if (match.embedded() != null) {
                    retrievals.add(toRetrieval(match));
                }
            }

            RetrieveQueryResult result = new RetrieveQueryResult();
            result.setQuery(queryText);
            // NOTE(review): this orders by distance (1 - score) DESCENDING, i.e. least similar
            // first, and limit(num) never trims because maxResults(num) already caps the search.
            // Ordering kept unchanged for existing callers -- confirm it is intended.
            result.setRetrieval(retrievals.stream()
                    .sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
                    .limit(num)
                    .collect(Collectors.toList()));
            results.add(result);
        }
        return results;
    }

    /**
     * Converts a store match into a {@link Retrieval}. The distance is {@code 1 - score}, and
     * the segment's metadata is copied into a fresh map.
     *
     * @param match match whose {@code embedded()} segment is non-null
     */
    private static Retrieval toRetrieval(EmbeddingMatch<TextSegment> match) {
        TextSegment segment = match.embedded();
        Retrieval retrieval = new Retrieval();
        retrieval.setDistance(1.0d - match.score());
        retrieval.setId(TextSegmentConvert.getQueryId(segment));
        retrieval.setQuery(segment.text());
        Map<String, Object> metadata = new HashMap<>();
        Map<String, ?> segmentMetadata = segment.metadata().toMap();
        if (MapUtils.isNotEmpty(segmentMetadata)) {
            metadata.putAll(segmentMetadata);
        }
        retrieval.setMetadata(metadata);
        return retrieval;
    }

    /**
     * Builds the cache key for a query id within a collection. A NUL separator keeps
     * (collection, id) pairs distinct, so the same id in two collections never collides.
     */
    private static String cacheKey(String collectionName, String queryId) {
        return collectionName + '\u0000' + queryId;
    }

    /**
     * ANDs together one {@link IsEqualTo} condition per map entry.
     *
     * @param map metadata key to required value
     * @return the combined filter, or {@code null} when {@code map} is null or empty
     */
    private static Filter createCombinedFilter(Map<String, String> map) {
        if (MapUtils.isEmpty(map)) {
            return null;
        }
        Filter combined = null;
        for (Map.Entry<String, String> entry : map.entrySet()) {
            Filter condition = new IsEqualTo(entry.getKey(), entry.getValue());
            combined = (combined == null) ? condition : Filter.and(combined, condition);
        }
        return combined;
    }
}
