/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.tensor.TensorShape;
import com.google.common.base.Preconditions;

/**
 * Set of numeric kernels that operate on {@link AbstractTensor} instances.
 *
 * <p>Implementations provide the abstract primitives (the nine-argument
 * {@code batchDotProduct}, {@code accumulate}, {@code maccumulate}, the scalar
 * {@code saxpy} and {@code scale}). Every {@code default} method in this
 * interface is a convenience overload that forwards to one of those primitives.
 */
public interface TensorOperations {
    /**
     * One {@link FloatBufferTensor} per thread, created lazily from {@code TensorShape.one}.
     * The five-argument {@code dotProduct} passes it to {@code batchDotProduct} as the result
     * tensor and then reads element {@code (0, 0)} back.
     */
    ThreadLocal<FloatBufferTensor> scratch =
            ThreadLocal.withInitial(() -> new FloatBufferTensor(TensorShape.one));

    /**
     * @return the name of this implementation
     */
    String name();

    /**
     * @return the parallel split size; this default returns {@code 1}
     */
    default int parallelSplitSize() {
        return 1;
    }

    /**
     * Dot product of {@code a} and {@code b} with both offsets set to {@code 0}.
     *
     * @param a     first operand
     * @param b     second operand
     * @param limit forwarded unchanged as the {@code limit} of the five-argument overload
     * @return the value produced by the five-argument overload
     */
    default float dotProduct(AbstractTensor a, AbstractTensor b, int limit) {
        return dotProduct(a, b, 0, 0, limit);
    }

    /**
     * Dot product computed through {@code batchDotProduct}, using the calling thread's
     * {@link #scratch} tensor as the result buffer.
     *
     * @param a       first operand
     * @param b       second operand
     * @param aoffset forwarded as the column offset for {@code a}
     * @param boffset forwarded as the column offset for {@code b}
     * @param limit   forwarded as the column limit
     * @return element {@code (0, 0)} of the scratch tensor after the batch call returns
     */
    default float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int boffset, int limit) {
        FloatBufferTensor perThreadResult = scratch.get();
        batchDotProduct(perThreadResult, a, b, aoffset, boffset, limit);
        return perThreadResult.get(0, 0);
    }

    /**
     * Batch dot product over all of {@code b}: forwards to the nine-argument primitive with
     * {@code 0}, {@code 0} and {@code b.shape().first()} as the three trailing arguments.
     *
     * @param result        tensor handed to the primitive as its result argument
     * @param a             first operand
     * @param b             second operand
     * @param aColumnOffset column offset for {@code a}
     * @param bColumnOffset column offset for {@code b}
     * @param columnLimit   column limit
     */
    default void batchDotProduct(AbstractTensor result, AbstractTensor a, AbstractTensor b, int aColumnOffset, int bColumnOffset, int columnLimit) {
        final int allRowsOfB = b.shape().first();
        batchDotProduct(result, a, b, aColumnOffset, bColumnOffset, columnLimit, 0, 0, allRowsOfB);
    }

    /**
     * Primitive batch dot product that implementations must supply.
     *
     * <p>The parameter names are decompiler-generated. Judging from the call sites in this
     * interface the arguments are, in order: result tensor, operand {@code a}, operand
     * {@code b}, column offset into {@code a}, column offset into {@code b}, column limit,
     * a value that both in-file callers pass as {@code 0} (NOTE(review): probably a row
     * offset for the result or for {@code a} - confirm against an implementation), the row
     * offset, and the row chunk size.
     */
    void batchDotProduct(AbstractTensor var1, AbstractTensor var2, AbstractTensor var3, int var4, int var5, int var6, int var7, int var8, int var9);

    /**
     * Batch dot product restricted to a chunk of rows. The same {@code columnOffset} is used
     * for both operands and the seventh primitive argument is fixed at {@code 0}.
     *
     * @param result       tensor handed to the primitive as its result argument
     * @param a            first operand
     * @param b            second operand
     * @param columnOffset column offset applied to both {@code a} and {@code b}
     * @param columnLimit  column limit
     * @param rowOffset    forwarded as the eighth primitive argument
     * @param rowChunkSize forwarded as the ninth primitive argument
     */
    default void dotProductChunk(AbstractTensor result, AbstractTensor a, AbstractTensor b, int columnOffset, int columnLimit, int rowOffset, int rowChunkSize) {
        batchDotProduct(result, a, b, columnOffset, columnOffset, columnLimit, 0, rowOffset, rowChunkSize);
    }

    /**
     * Runs {@code dotProductChunk} once per {@code (result[i], b[i])} pair, always against
     * the same {@code a}.
     *
     * @param result     one result tensor per entry of {@code b}
     * @param a          operand shared by every call
     * @param b          second operands; {@code b[0]} must report two dimensions
     * @param offset     column offset forwarded to each call
     * @param limit      column limit forwarded to each call
     * @param chunkStart row offset forwarded to each call
     * @param chunkSize  row chunk size forwarded to each call
     * @throws IllegalArgumentException if {@code b[0].dims() != 2} or the two arrays differ
     *                                  in length
     */
    default void dotProductBatchChunk(AbstractTensor[] result, AbstractTensor a, AbstractTensor[] b, int offset, int limit, int chunkStart, int chunkSize) {
        // Only the first element of b has its rank inspected.
        Preconditions.checkArgument(b[0].dims() == 2 && result.length == b.length);
        final int pairCount = result.length;
        for (int idx = 0; idx < pairCount; idx++) {
            dotProductChunk(result[idx], a, b[idx], offset, limit, chunkStart, chunkSize);
        }
    }

    /**
     * Primitive supplied by implementations. Parameter names are decompiler-generated and
     * nothing in this interface calls the method.
     * NOTE(review): presumably adds the second tensor into the first over an
     * (offset, length) range - confirm against an implementation.
     */
    void accumulate(AbstractTensor var1, AbstractTensor var2, int var3, int var4);

    /**
     * Primitive supplied by implementations. Parameter names are decompiler-generated and
     * nothing in this interface calls the method.
     * NOTE(review): presumably multiplies the second tensor into the first over an
     * (offset, length) range - confirm against an implementation.
     */
    void maccumulate(AbstractTensor var1, AbstractTensor var2, int var3, int var4);

    /**
     * Scalar saxpy primitive supplied by implementations.
     *
     * <p>Per the batched overload in this interface, the arguments are, in order: the scalar
     * {@code alpha}, tensor {@code x}, tensor {@code y}, the offset into {@code x}, the
     * offset into {@code y}, and the limit.
     * NOTE(review): by BLAS naming this should compute {@code y += alpha * x} - confirm
     * against an implementation.
     */
    void saxpy(float var1, AbstractTensor var2, AbstractTensor var3, int var4, int var5, int var6);

    /**
     * Batched saxpy: for each of {@code batchSize} consecutive rows of {@code x}, starting
     * at {@code xRowOffset}, invokes the scalar {@code saxpy} with the matching coefficient
     * read from row {@code 0} of {@code alpha}, starting at column {@code aOffset}. Every
     * call targets the same {@code y}.
     *
     * @param alpha      coefficients; {@code alpha.shape().last()} must equal
     *                   {@code x.shape().first()}
     * @param x          source tensor, sliced one row per call
     * @param y          destination tensor; {@code y.shape().first()} must be {@code 1}
     * @param xoffset    offset into each slice of {@code x}
     * @param yoffset    offset into {@code y}
     * @param limit      limit forwarded to each scalar call
     * @param aOffset    column of {@code alpha} holding the first coefficient
     * @param xRowOffset first row of {@code x} to process
     * @param batchSize  number of rows to process
     * @throws IllegalArgumentException if either shape requirement is not met
     */
    default void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int aOffset, int xRowOffset, int batchSize) {
        Preconditions.checkArgument(alpha.shape().last() == x.shape().first() && y.shape().first() == 1);
        final int endRow = xRowOffset + batchSize;
        for (int row = xRowOffset, alphaColumn = aOffset; row < endRow; row++, alphaColumn++) {
            // The coefficient is fetched before the row is sliced.
            float coefficient = alpha.get(0, alphaColumn);
            AbstractTensor xRow = x.slice(row);
            saxpy(coefficient, xRow, y, xoffset, yoffset, limit);
        }
    }

    /**
     * Primitive supplied by implementations. Parameter names are decompiler-generated and
     * nothing in this interface calls the method.
     * NOTE(review): presumably multiplies a range of the tensor by the given factor -
     * confirm against an implementation.
     */
    void scale(float var1, AbstractTensor var2, int var3, int var4);

    /**
     * Default "quantization" that performs no type conversion: a tensor with the same dtype
     * and shape as {@code t} is requested from {@code TensorCache.instance}, filled via
     * {@code copyFrom(t, offset, offset, length)}, and returned.
     *
     * @param t      tensor to copy from
     * @param qtype  requested target type; not used by this default
     * @param offset passed as both offset arguments of {@code copyFrom}
     * @param length passed as the length argument of {@code copyFrom}
     * @return the tensor obtained from the cache
     */
    default AbstractTensor quantize(AbstractTensor t, DType qtype, int offset, int length) {
        AbstractTensor cachedCopy = TensorCache.instance.get(t.dType(), t.shape());
        cachedCopy.copyFrom(t, offset, offset, length);
        return cachedCopy;
    }

    /**
     * Adds up the values of {@code a} visited by its cursor iteration.
     *
     * <p>The cursor array holds one index per dimension. The loop reads a value each time
     * {@code a.iterate(cursor)} returns {@code true} and stops on the first {@code false}.
     *
     * @param a tensor to sum
     * @return the accumulated single-precision total
     */
    default float sum(AbstractTensor a) {
        float total = 0.0f;
        for (int[] cursor = new int[a.dims()]; a.iterate(cursor); ) {
            total += a.get(cursor);
        }
        return total;
    }
}

