package dev.langchain4j.store.embedding.inmemory;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.spi.ServiceHelper;
import dev.langchain4j.spi.store.embedding.inmemory.InMemoryEmbeddingStoreJsonCodecFactory;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.class */
public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
    public Set<Entry<Embedded>> entries = new CopyOnWriteArraySet();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore$Entry.class */
    public static class Entry<Embedded> {
        String id;
        Embedding embedding;
        Embedded embedded;

        Entry(String str, Embedding embedding) {
            this(str, embedding, null);
        }

        Entry(String str, Embedding embedding, Embedded embedded) {
            this.id = ValidationUtils.ensureNotBlank(str, "id");
            this.embedding = (Embedding) ValidationUtils.ensureNotNull(embedding, "embedding");
            this.embedded = embedded;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Entry entry = (Entry) obj;
            return Objects.equals(this.id, entry.id) && Objects.equals(this.embedding, entry.embedding) && Objects.equals(this.embedded, entry.embedded);
        }

        public int hashCode() {
            return Objects.hash(this.id, this.embedding, this.embedded);
        }
    }

    public String add(Embedding embedding) {
        String randomUUID = Utils.randomUUID();
        add(randomUUID, embedding);
        return randomUUID;
    }

    public void add(String str, Embedding embedding) {
        add(str, embedding, null);
    }

    public String add(Embedding embedding, Embedded embedded) {
        String randomUUID = Utils.randomUUID();
        add(randomUUID, embedding, embedded);
        return randomUUID;
    }

    public void add(String str, Embedding embedding, Embedded embedded) {
        this.entries.add(new Entry<>(str, embedding, embedded));
    }

    private List<String> add(List<Entry<Embedded>> list) {
        this.entries.addAll(list);
        return (List) list.stream().map(entry -> {
            return entry.id;
        }).collect(Collectors.toList());
    }

    public List<String> addAll(List<Embedding> list) {
        return add((List) list.stream().map(embedding -> {
            return new Entry(Utils.randomUUID(), embedding);
        }).collect(Collectors.toList()));
    }

    public List<String> addAll(List<Embedding> list, List<Embedded> list2) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        return add((List) IntStream.range(0, list.size()).mapToObj(i -> {
            return new Entry(Utils.randomUUID(), (Embedding) list.get(i), list2.get(i));
        }).collect(Collectors.toList()));
    }

    public void removeAll(Collection<String> collection) {
        ValidationUtils.ensureNotEmpty(collection, "ids");
        this.entries.removeIf(entry -> {
            return collection.contains(entry.id);
        });
    }

    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        this.entries.removeIf(entry -> {
            if (entry.embedded instanceof TextSegment) {
                return filter.test(((TextSegment) entry.embedded).metadata());
            }
            if (entry.embedded == null) {
                return false;
            }
            throw new UnsupportedOperationException("Not supported yet.");
        });
    }

    public void removeAll() {
        this.entries.clear();
    }

    public EmbeddingSearchResult<Embedded> search(EmbeddingSearchRequest embeddingSearchRequest) {
        Comparator comparingDouble = Comparator.comparingDouble((v0) -> {
            return v0.score();
        });
        PriorityQueue priorityQueue = new PriorityQueue(comparingDouble);
        Filter filter = embeddingSearchRequest.filter();
        for (Entry<Embedded> entry : this.entries) {
            if (filter == null || !(entry.embedded instanceof TextSegment) || filter.test(((TextSegment) entry.embedded).metadata())) {
                double fromCosineSimilarity = RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(entry.embedding, embeddingSearchRequest.queryEmbedding()));
                if (fromCosineSimilarity >= embeddingSearchRequest.minScore()) {
                    priorityQueue.add(new EmbeddingMatch(Double.valueOf(fromCosineSimilarity), entry.id, entry.embedding, entry.embedded));
                    if (priorityQueue.size() > embeddingSearchRequest.maxResults()) {
                        priorityQueue.poll();
                    }
                }
            }
        }
        ArrayList arrayList = new ArrayList(priorityQueue);
        arrayList.sort(comparingDouble);
        Collections.reverse(arrayList);
        return new EmbeddingSearchResult<>(arrayList);
    }

    public String serializeToJson() {
        return loadCodec().toJson(this);
    }

    public void serializeToFile(Path path) {
        try {
            Files.write(path, serializeToJson().getBytes(), StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public void serializeToFile(String str) {
        serializeToFile(Paths.get(str, new String[0]));
    }

    public static InMemoryEmbeddingStore<TextSegment> fromJson(String str) {
        return loadCodec().fromJson(str);
    }

    public static InMemoryEmbeddingStore<TextSegment> fromFile(Path path) {
        try {
            return fromJson(new String(Files.readAllBytes(path)));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public static InMemoryEmbeddingStore<TextSegment> fromFile(String str) {
        return fromFile(Paths.get(str, new String[0]));
    }

    private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
        Iterator it = ServiceHelper.loadFactories(InMemoryEmbeddingStoreJsonCodecFactory.class).iterator();
        return it.hasNext() ? ((InMemoryEmbeddingStoreJsonCodecFactory) it.next()).create() : new GsonInMemoryEmbeddingStoreJsonCodec();
    }
}
