/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model.functions;

import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.tensor.TensorShape;
import com.google.common.base.Preconditions;

public interface EmbedInput {
    /**
     * Looks up the embedding for a single input token.
     *
     * <p>{@link #batchInputsToEmbeddings} closes every tensor returned from this method after
     * copying it (except in the single-token case, where the tensor is returned to its caller).
     *
     * @param inputToken the token id to embed
     * @param position the token's position in the sequence; the batch method passes
     *     {@code startPos + i} for the token at index {@code i}
     * @return a tensor holding the token's embedding
     */
    AbstractTensor inputTokenToEmbedding(int inputToken, int position);

    /**
     * Embeds a run of tokens into one tensor with one row per token.
     *
     * <p>With a single token, the tensor from {@link #inputTokenToEmbedding} is returned as-is.
     * Otherwise a tensor of shape {@code (inputTokens.length, dim)} is obtained from
     * {@link TensorCache}, where {@code dim} is the last dimension of the first token's
     * embedding, and row {@code i} is filled with the embedding of {@code inputTokens[i]} at
     * position {@code startPos + i}. Per-token tensors are closed once copied, including when
     * the copy fails.
     *
     * @param inputTokens the token ids to embed; must be non-empty
     * @param startPos the sequence position of {@code inputTokens[0]}
     * @return the embedding tensor for the whole run
     * @throws IllegalArgumentException if {@code inputTokens} is empty
     * @throws NullPointerException if {@code inputTokens} is null
     */
    default AbstractTensor batchInputsToEmbeddings(int[] inputTokens, int startPos) {
        Preconditions.checkArgument(inputTokens.length > 0, "inputTokens must not be empty");
        AbstractTensor first = this.inputTokenToEmbedding(inputTokens[0], startPos);
        if (inputTokens.length == 1) {
            // Nothing to batch: hand the single embedding straight back to the caller.
            return first;
        }

        final AbstractTensor batch;
        try {
            int dim = first.shape().last();
            batch = TensorCache.instance.get(first.dType(), TensorShape.of(inputTokens.length, dim));
            boolean copied = false;
            try {
                batch.copyFrom(first, 0, 0, dim);
                copied = true;
            } finally {
                // The batch never escaped this method, so release it if row 0 could not be filled.
                if (!copied) {
                    batch.close();
                }
            }
        } finally {
            first.close();
        }

        // Fill rows 1..n-1. Each iteration writes the disjoint range starting at i * width.
        // NOTE(review): VectorMath.pfor presumably runs iterations concurrently, which is why the
        // batch is NOT closed here on failure: sibling iterations might still be writing into it.
        VectorMath.pfor(1, inputTokens.length, i -> {
            AbstractTensor embedding = this.inputTokenToEmbedding(inputTokens[i], startPos + i);
            try {
                int width = embedding.shape().last();
                batch.copyFrom(embedding, 0, i * width, width);
            } finally {
                embedding.close();
            }
        });
        return batch;
    }
}

