package dev.langchain4j.model.embedding.onnx;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.ValidationUtils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:dev/langchain4j/model/embedding/onnx/OnnxBertBiEncoder.class */
/**
 * Computes text embeddings by running a tokenizer and an ONNX model session.
 *
 * <p>The flow implemented by {@link #embed(String)} is:
 * <ol>
 *   <li>tokenize the text with the {@link HuggingFaceTokenizer};</li>
 *   <li>split the tokens into partitions of at most {@code MAX_SEQUENCE_LENGTH} tokens;</li>
 *   <li>run every partition through the {@link OrtSession} and pool the first model output
 *       according to the configured {@link PoolingMode};</li>
 *   <li>combine the per-partition vectors with an average weighted by partition token count;</li>
 *   <li>scale the combined vector to unit Euclidean length.</li>
 * </ol>
 *
 * <p>NOTE(review): this class never closes the session or the tokenizer it holds; whether the
 * owner is expected to do so is not visible from this file.
 */
public class OnnxBertBiEncoder {
    /**
     * Maximum number of tokens placed in one partition.
     * NOTE(review): presumably the model's 512-position limit minus two special tokens - confirm.
     */
    private static final int MAX_SEQUENCE_LENGTH = 510;
    private final OrtEnvironment environment;
    private final OrtSession session;
    /** Input names reported by the session; consulted to decide whether to feed {@code token_type_ids}. */
    private final Set<String> expectedInputs;
    private final HuggingFaceTokenizer tokenizer;
    private final PoolingMode poolingMode;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Simple holder pairing an embedding vector with the number of tokens the tokenizer produced
     * for the embedded text.
     */
    public static class EmbeddingAndTokenCount {
        /** The embedding vector. */
        float[] embedding;
        /** Size of the token list returned by the tokenizer for the whole text. */
        int tokenCount;

        /**
         * @param fArr the embedding vector (stored as is, not copied)
         * @param i    the token count
         */
        EmbeddingAndTokenCount(float[] fArr, int i) {
            this.embedding = fArr;
            this.tokenCount = i;
        }
    }

    /**
     * Creates an encoder that loads the model from a stream and owns the resulting session.
     *
     * @param inputStream  stream with the serialized ONNX model; fully read and then closed
     * @param inputStream2 stream with the tokenizer definition handed to
     *                     {@link HuggingFaceTokenizer#newInstance}
     * @param poolingMode  how token vectors are reduced to one vector; must not be {@code null}
     * @throws RuntimeException wrapping any failure raised while loading the model, creating the
     *                          tokenizer or validating the arguments
     */
    public OnnxBertBiEncoder(InputStream inputStream, InputStream inputStream2, PoolingMode poolingMode) {
        // Tracked in a local so the catch block can release it if a later step fails.
        OrtSession createdSession = null;
        try {
            OrtEnvironment ortEnvironment = OrtEnvironment.getEnvironment();
            createdSession = ortEnvironment.createSession(loadModel(inputStream));
            this.environment = ortEnvironment;
            this.session = createdSession;
            this.expectedInputs = createdSession.getInputNames();
            this.tokenizer = HuggingFaceTokenizer.newInstance(inputStream2, Collections.singletonMap("padding", "false"));
            this.poolingMode = (PoolingMode) ValidationUtils.ensureNotNull(poolingMode, "poolingMode");
        } catch (Exception e) {
            // The session was created here, so nobody else can close it once construction fails.
            if (createdSession != null) {
                try {
                    createdSession.close();
                } catch (Exception closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an encoder on top of an existing environment and session supplied by the caller.
     *
     * @param ortEnvironment environment used to create the input tensors
     * @param ortSession     session used to run the model
     * @param inputStream    stream with the tokenizer definition handed to
     *                       {@link HuggingFaceTokenizer#newInstance}
     * @param poolingMode    how token vectors are reduced to one vector; must not be {@code null}
     * @throws RuntimeException wrapping any failure raised while creating the tokenizer or
     *                          validating the arguments
     */
    public OnnxBertBiEncoder(OrtEnvironment ortEnvironment, OrtSession ortSession, InputStream inputStream, PoolingMode poolingMode) {
        try {
            this.environment = ortEnvironment;
            this.session = ortSession;
            this.expectedInputs = ortSession.getInputNames();
            this.tokenizer = HuggingFaceTokenizer.newInstance(inputStream, Collections.singletonMap("padding", "false"));
            this.poolingMode = (PoolingMode) ValidationUtils.ensureNotNull(poolingMode, "poolingMode");
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Embeds the given text.
     *
     * <p>The text is tokenized, split into partitions, each partition is encoded and pooled, and the
     * partition vectors are averaged (weighted by partition size) and scaled to unit length.
     *
     * @param str the text to embed
     * @return the resulting vector together with the total number of tokens produced by the tokenizer
     * @throws IllegalArgumentException if tokenizing the text leaves nothing to encode
     * @throws RuntimeException         wrapping an {@link OrtException} raised while running the model
     */
    public EmbeddingAndTokenCount embed(String str) {
        List<String> tokens = this.tokenizer.tokenize(str);
        List<List<String>> partitions = partition(tokens, MAX_SEQUENCE_LENGTH);
        if (partitions.isEmpty()) {
            // Without this guard the averaging step fails with an opaque IndexOutOfBoundsException.
            throw new IllegalArgumentException("Cannot embed text that yields no tokens to encode");
        }

        List<float[]> embeddings = new ArrayList<>(partitions.size());
        for (List<String> part : partitions) {
            // try-with-resources releases the native result even when pooling throws.
            try (OrtSession.Result result = encode(part)) {
                embeddings.add(toEmbedding(result));
            } catch (OrtException e) {
                throw new RuntimeException(e);
            }
        }

        List<Integer> weights = partitions.stream().map(List::size).collect(Collectors.toList());
        float[] embedding = normalize(weightedAverage(embeddings, weights));
        return new EmbeddingAndTokenCount(embedding, tokens.size());
    }

    /**
     * Splits a token list into consecutive sub-lists of at most {@code i} tokens each.
     *
     * <p>The first and the last token of {@code list} are never part of any partition
     * (NOTE(review): presumably the tokenizer's leading/trailing special tokens - confirm).
     * A partition boundary is moved backwards while the token at the boundary starts with
     * {@code "##"}, so that a run of such continuation tokens stays together with the token that
     * precedes it. If that is impossible because the whole window consists of {@code "##"} tokens,
     * the window is cut at exactly {@code i} tokens instead, which guarantees progress.
     *
     * <p>The returned partitions are {@link List#subList} views of {@code list}.
     *
     * @param list the complete token list
     * @param i    the maximum partition size; must be positive
     * @return the partitions in order; empty when {@code list} has fewer than three tokens
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    static List<List<String>> partition(List<String> list, int i) {
        if (i <= 0) {
            // A non-positive size could never advance through the list.
            throw new IllegalArgumentException("Partition size must be positive, but was: " + i);
        }
        List<List<String>> partitions = new ArrayList<>();
        final int end = list.size() - 1; // exclusive upper bound: the last token is skipped
        int from = 1;                    // the first token is skipped
        while (from < end) {
            int to;
            if (i >= end - from) {
                // Everything that is left fits (written this way to avoid int overflow of from + i).
                to = end;
            } else {
                to = from + i;
                while (to > from && list.get(to).startsWith("##")) {
                    to--;
                }
                if (to == from) {
                    // No boundary without "##" inside the window: cut hard rather than loop forever.
                    to = from + i;
                }
            }
            partitions.add(list.subList(from, to));
            from = to;
        }
        return partitions;
    }

    /**
     * Runs the model on one partition of tokens.
     *
     * <p>The input tensors are closed before returning; the caller owns the returned result and
     * must close it.
     *
     * @param tokens the tokens of a single partition
     * @return the raw session result
     * @throws OrtException if tensor creation or model execution fails
     */
    private OrtSession.Result encode(List<String> tokens) throws OrtException {
        Encoding encoding = this.tokenizer.encode(toText(tokens), true, false);
        long[] inputIds = encoding.getIds();
        long[] attentionMask = encoding.getAttentionMask();
        long[] tokenTypeIds = encoding.getTypeIds();
        long[] shape = {1, inputIds.length}; // a batch containing one sequence

        try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(this.environment, LongBuffer.wrap(inputIds), shape);
             OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(this.environment, LongBuffer.wrap(attentionMask), shape);
             OnnxTensor tokenTypeIdsTensor = OnnxTensor.createTensor(this.environment, LongBuffer.wrap(tokenTypeIds), shape)) {
            HashMap<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input_ids", inputIdsTensor);
            inputs.put("attention_mask", attentionMaskTensor);
            // Only passed when the session lists it among its input names.
            if (this.expectedInputs.contains("token_type_ids")) {
                inputs.put("token_type_ids", tokenTypeIdsTensor);
            }
            return this.session.run(inputs);
        }
    }

    /**
     * Converts tokens back to text for re-encoding.
     *
     * <p>The sentence built by the tokenizer is used when tokenizing it again (ignoring the first
     * and last token) reproduces exactly {@code tokens}; otherwise the tokens are concatenated
     * without a separator.
     *
     * @param tokens the tokens of a single partition
     * @return text to feed to the tokenizer
     */
    private String toText(List<String> tokens) {
        String sentence = this.tokenizer.buildSentence(tokens);
        List<String> roundTripped = new ArrayList<>(this.tokenizer.tokenize(sentence));
        roundTripped.remove(0);
        roundTripped.remove(roundTripped.size() - 1);
        return roundTripped.equals(tokens) ? sentence : String.join("", tokens);
    }

    /**
     * Pools the first output of a session result into a single vector.
     *
     * @param result the session result; its first output is read as a {@code float[][][]} and
     *               only the first element of the outermost dimension is used
     * @return the pooled vector
     * @throws OrtException if the output value cannot be read
     */
    private float[] toEmbedding(OrtSession.Result result) throws OrtException {
        float[][][] output = (float[][][]) result.get(0).getValue();
        return pool(output[0]);
    }

    /**
     * Reduces per-token vectors to one vector according to the configured pooling mode.
     *
     * @param tokenVectors one vector per token
     * @return the pooled vector
     * @throws IllegalArgumentException if the pooling mode is neither {@code CLS} nor {@code MEAN}
     */
    private float[] pool(float[][] tokenVectors) {
        switch (this.poolingMode) {
            case CLS:
                return clsPool(tokenVectors);
            case MEAN:
                return meanPool(tokenVectors);
            default:
                throw Exceptions.illegalArgument("Unknown pooling mode: " + this.poolingMode);
        }
    }

    /** Returns the vector of the first token (the same array instance, not a copy). */
    private static float[] clsPool(float[][] tokenVectors) {
        return tokenVectors[0];
    }

    /** Returns the element-wise arithmetic mean of all token vectors. */
    private static float[] meanPool(float[][] tokenVectors) {
        int tokenCount = tokenVectors.length;
        int dimensions = tokenVectors[0].length;
        float[] mean = new float[dimensions];
        for (float[] tokenVector : tokenVectors) {
            for (int d = 0; d < dimensions; d++) {
                mean[d] += tokenVector[d];
            }
        }
        for (int d = 0; d < dimensions; d++) {
            mean[d] /= tokenCount;
        }
        return mean;
    }

    /**
     * Averages vectors element-wise, weighting each vector by its integer weight.
     *
     * @param embeddings the vectors to combine; a single vector is returned unchanged
     * @param weights    one weight per vector, in the same order
     * @return the weighted average
     */
    private float[] weightedAverage(List<float[]> embeddings, List<Integer> weights) {
        if (embeddings.size() == 1) {
            return embeddings.get(0);
        }
        int dimensions = embeddings.get(0).length;
        float[] averaged = new float[dimensions];
        int totalWeight = 0;
        for (int p = 0; p < embeddings.size(); p++) {
            int weight = weights.get(p).intValue();
            totalWeight += weight;
            float[] embedding = embeddings.get(p);
            for (int d = 0; d < dimensions; d++) {
                averaged[d] += embedding[d] * weight;
            }
        }
        for (int d = 0; d < dimensions; d++) {
            averaged[d] /= totalWeight;
        }
        return averaged;
    }

    /**
     * Returns a copy of the vector scaled to unit Euclidean length.
     *
     * <p>A vector whose norm is zero is returned as an unscaled copy, because dividing by the norm
     * would turn every component into NaN.
     *
     * @param vector the vector to scale; not modified
     * @return a new array holding the scaled values
     */
    private static float[] normalize(float[] vector) {
        float sumOfSquares = 0.0f;
        for (float component : vector) {
            sumOfSquares += component * component;
        }
        float norm = (float) Math.sqrt(sumOfSquares);
        if (norm == 0.0f) {
            return vector.clone();
        }
        float[] normalized = new float[vector.length];
        for (int d = 0; d < vector.length; d++) {
            normalized[d] = vector[d] / norm;
        }
        return normalized;
    }

    /**
     * Counts the tokens the tokenizer produces for the given text.
     *
     * @param str the text to tokenize
     * @return the size of the token list returned by the tokenizer
     */
    int countTokens(String str) {
        return this.tokenizer.tokenize(str).size();
    }

    /**
     * Reads the whole stream into memory and closes it, also when reading fails.
     *
     * @param modelStream the stream to drain
     * @return all bytes read from the stream
     * @throws RuntimeException wrapping an {@link IOException} raised while reading or closing
     */
    private byte[] loadModel(InputStream modelStream) {
        try (InputStream in = modelStream;
             ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
            byte[] chunk = new byte[1024];
            int read;
            while ((read = in.read(chunk, 0, chunk.length)) != -1) {
                buffer.write(chunk, 0, read);
            }
            buffer.flush();
            return buffer.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
