/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.supersonic.common.service.impl;

import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections.MapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * {@link EmbeddingService} implementation that embeds text with the injected
 * {@link EmbeddingModel} and stores/searches the vectors in an {@link EmbeddingStore}
 * obtained per collection name from the injected {@link EmbeddingStoreFactory}.
 */
@Service
public class EmbeddingServiceImpl
implements EmbeddingService {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingServiceImpl.class);
    @Autowired
    private EmbeddingStoreFactory embeddingStoreFactory;
    @Autowired
    private EmbeddingModel embeddingModel;

    /**
     * Embeds the text of every segment and adds the embedding together with the
     * segment to the store created for {@code collectionName}.
     *
     * <p>This is best-effort: a failure while embedding or storing one segment is
     * logged and the remaining segments are still processed.
     *
     * @param collectionName name passed to the store factory to obtain the store
     * @param queries segments to embed and store
     */
    @Override
    public void addQuery(String collectionName, List<TextSegment> queries) {
        // Unchecked only if the factory is declared with a raw EmbeddingStore return type.
        @SuppressWarnings("unchecked")
        EmbeddingStore<TextSegment> embeddingStore = this.embeddingStoreFactory.create(collectionName);
        for (TextSegment query : queries) {
            String question = query.text();
            try {
                Embedding embedding = this.embeddingModel.embed(question).content();
                embeddingStore.add(embedding, query);
            }
            catch (Exception e) {
                // The trailing throwable is not bound to a placeholder, so SLF4J logs its stack trace.
                log.error("embeddingModel embed error question: {}, embeddingStore: {}", question, embeddingStore.getClass().getSimpleName(), e);
            }
        }
    }

    /**
     * Currently a no-op: nothing is removed from any store.
     *
     * @param collectionName ignored
     * @param queries ignored
     */
    @Override
    public void deleteQuery(String collectionName, List<TextSegment> queries) {
    }

    /**
     * Runs one embedding search per query text and returns one result object per text,
     * in the order of {@code retrieveQuery.getQueryTextsList()}.
     *
     * @param collectionName name passed to the store factory to obtain the store
     * @param retrieveQuery query texts plus an optional key/value filter condition;
     *        all filter entries are combined with logical AND
     * @param num maximum number of matches requested from the store and kept per query text
     * @return one {@link RetrieveQueryResult} per query text
     */
    @Override
    public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
        List<RetrieveQueryResult> results = new ArrayList<>();
        // Unchecked only if the factory is declared with a raw EmbeddingStore return type.
        @SuppressWarnings("unchecked")
        EmbeddingStore<TextSegment> embeddingStore = this.embeddingStoreFactory.create(collectionName);
        List<String> queryTextsList = retrieveQuery.getQueryTextsList();
        // The filter depends only on the filter condition, so build it once for all query texts.
        Filter filter = EmbeddingServiceImpl.createCombinedFilter(retrieveQuery.getFilterCondition());
        for (String queryText : queryTextsList) {
            Embedding embeddedText = this.embeddingModel.embed(queryText).content();
            EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
                    .queryEmbedding(embeddedText)
                    .filter(filter)
                    .maxResults(num)
                    .build();
            EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(request);

            List<Retrieval> retrievals = new ArrayList<>();
            for (EmbeddingMatch<TextSegment> match : searchResult.matches()) {
                retrievals.add(EmbeddingServiceImpl.toRetrieval(match));
            }
            // NOTE(review): this orders by DESCENDING distance (largest distance first). Since
            // distance is 1 - score, the least similar match comes first. Kept as in the original
            // because callers are not visible here; confirm whether ascending order was intended.
            List<Retrieval> ordered = retrievals.stream()
                    .sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
                    .limit(num)
                    .collect(Collectors.toList());

            RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
            retrieveQueryResult.setQuery(queryText);
            retrieveQueryResult.setRetrieval(ordered);
            results.add(retrieveQueryResult);
        }
        return results;
    }

    /**
     * Converts one store match into a {@link Retrieval}.
     *
     * <p>The distance is {@code 1 - score}. Id, query text and metadata are copied from the
     * embedded segment only when the match carries one; otherwise id and query are left unset
     * and the metadata map is empty, instead of failing with a {@link NullPointerException}.
     */
    private static Retrieval toRetrieval(EmbeddingMatch<TextSegment> match) {
        Retrieval retrieval = new Retrieval();
        retrieval.setDistance(1.0 - match.score());

        Map<String, Object> metadata = new HashMap<>();
        TextSegment embedded = match.embedded();
        if (embedded != null) {
            retrieval.setId(TextSegmentConvert.getQueryId(embedded));
            retrieval.setQuery(embedded.text());
            // Read the metadata map once instead of converting it twice.
            Map<String, Object> segmentMetadata = embedded.metadata().toMap();
            if (MapUtils.isNotEmpty(segmentMetadata)) {
                metadata.putAll(segmentMetadata);
            }
        }
        retrieval.setMetadata(metadata);
        return retrieval;
    }

    /**
     * Builds a filter requiring every entry of {@code map} to match by equality,
     * chaining the per-entry filters with logical AND.
     *
     * @param map metadata key to required value; may be {@code null} or empty
     * @return the combined filter, or {@code null} when {@code map} is {@code null} or empty
     */
    private static Filter createCombinedFilter(Map<String, String> map) {
        if (MapUtils.isEmpty(map)) {
            return null;
        }
        // Typed as Filter (not IsEqualTo) because Filter.and(...) yields a composite filter.
        Filter result = null;
        for (Map.Entry<String, String> entry : map.entrySet()) {
            Filter isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue());
            result = result == null ? isEqualTo : Filter.and(result, isEqualTo);
        }
        return result;
    }
}

