/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.common.service.impl;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class EmbeddingServiceImpl
implements EmbeddingService {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingServiceImpl.class);
    @Autowired
    private EmbeddingStoreFactory embeddingStoreFactory;
    @Autowired
    private EmbeddingModel embeddingModel;
    private Cache<String, Boolean> cache = CacheBuilder.newBuilder().maximumSize(10000L).expireAfterWrite(10L, TimeUnit.HOURS).build();

    @Override
    public void addQuery(String collectionName, List<TextSegment> queries) {
        EmbeddingStore<TextSegment> embeddingStore = this.embeddingStoreFactory.create(collectionName);
        for (TextSegment query : queries) {
            String question = query.text();
            try {
                Embedding embedding = (Embedding)this.embeddingModel.embed(question).content();
                boolean existSegment = this.existSegment(embeddingStore, query, embedding);
                if (existSegment) continue;
                embeddingStore.add(embedding, (Object)query);
            }
            catch (Exception e) {
                log.error("embeddingModel embed error question: {}, embeddingStore: {}", new Object[]{question, embeddingStore.getClass().getSimpleName(), e});
            }
        }
    }

    private boolean existSegment(EmbeddingStore embeddingStore, TextSegment query, Embedding embedding) {
        String queryId = TextSegmentConvert.getQueryId(query);
        if (queryId == null) {
            return false;
        }
        Boolean cachedResult = (Boolean)this.cache.getIfPresent((Object)queryId);
        if (cachedResult != null) {
            return cachedResult;
        }
        HashMap<String, String> filterCondition = new HashMap<String, String>();
        filterCondition.put("queryId", queryId);
        Filter filter = EmbeddingServiceImpl.createCombinedFilter(filterCondition);
        EmbeddingSearchRequest request = EmbeddingSearchRequest.builder().queryEmbedding(embedding).filter(filter).maxResults(Integer.valueOf(1)).build();
        EmbeddingSearchResult result = embeddingStore.search(request);
        List relevant = result.matches();
        boolean exists = CollectionUtils.isNotEmpty((Collection)relevant);
        this.cache.put((Object)queryId, (Object)exists);
        return exists;
    }

    @Override
    public void deleteQuery(String collectionName, List<TextSegment> queries) {
        EmbeddingStore<TextSegment> embeddingStore = this.embeddingStoreFactory.create(collectionName);
        try {
            if (embeddingStore instanceof InMemoryEmbeddingStore) {
                InMemoryEmbeddingStore inMemoryEmbeddingStore = (InMemoryEmbeddingStore)embeddingStore;
                List queryIds = queries.stream().map(textSegment -> TextSegmentConvert.getQueryId(textSegment)).filter(Objects::nonNull).collect(Collectors.toList());
                if (CollectionUtils.isNotEmpty(queryIds)) {
                    MetadataFilterBuilder filterBuilder = new MetadataFilterBuilder("queryId");
                    Filter filter = filterBuilder.isIn(queryIds);
                    inMemoryEmbeddingStore.removeAll(filter);
                }
            }
        }
        catch (Exception e) {
            log.error("deleteQuery error,collectionName:{},queries:{}", (Object)collectionName, queries);
        }
    }

    @Override
    public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
        ArrayList<RetrieveQueryResult> results = new ArrayList<RetrieveQueryResult>();
        EmbeddingStore<TextSegment> embeddingStore = this.embeddingStoreFactory.create(collectionName);
        List<String> queryTextsList = retrieveQuery.getQueryTextsList();
        Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
        for (String queryText : queryTextsList) {
            Embedding embeddedText = (Embedding)this.embeddingModel.embed(queryText).content();
            Filter filter = EmbeddingServiceImpl.createCombinedFilter(filterCondition);
            EmbeddingSearchRequest request = EmbeddingSearchRequest.builder().queryEmbedding(embeddedText).filter(filter).maxResults(Integer.valueOf(num)).build();
            EmbeddingSearchResult result = embeddingStore.search(request);
            List relevant = result.matches();
            RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
            retrieveQueryResult.setQuery(queryText);
            ArrayList<Retrieval> retrievals = new ArrayList();
            for (EmbeddingMatch embeddingMatch : relevant) {
                Retrieval retrieval = new Retrieval();
                TextSegment embedded = (TextSegment)embeddingMatch.embedded();
                retrieval.setDistance(1.0 - embeddingMatch.score());
                retrieval.setId(TextSegmentConvert.getQueryId(embedded));
                retrieval.setQuery(embedded.text());
                HashMap<String, Object> metadata = new HashMap<String, Object>();
                if (Objects.nonNull(embedded) && MapUtils.isNotEmpty((Map)embedded.metadata().toMap())) {
                    metadata.putAll(embedded.metadata().toMap());
                }
                retrieval.setMetadata(metadata);
                retrievals.add(retrieval);
            }
            retrievals = retrievals.stream().sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed()).limit(num).collect(Collectors.toList());
            retrieveQueryResult.setRetrieval(retrievals);
            results.add(retrieveQueryResult);
        }
        return results;
    }

    private static Filter createCombinedFilter(Map<String, String> map) {
        IsEqualTo result = null;
        if (MapUtils.isEmpty(map)) {
            return null;
        }
        for (Map.Entry<String, String> entry : map.entrySet()) {
            IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), (Object)entry.getValue());
            result = result == null ? isEqualTo : Filter.and((Filter)result, (Filter)isEqualTo);
        }
        return result;
    }
}

