/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model;

import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

/**
 * Causal self-attention for a single model layer (identified by {@code layerIndex}).
 *
 * <p>The instance holds the query/key/value/output projection weights, their optional biases,
 * and values derived from the model {@link Config} (attention scale, attention length).
 *
 * <p>NOTE: this source was recovered by a decompiler. The body of
 * {@link #forward(AbstractTensor, int, KvBufferCache.KvBuffer, Optional)} could not be recovered
 * and currently throws unconditionally. The private {@code lambda$forward$N} methods appear to be
 * the compiler-generated lambda bodies that belonged to that method (going by their names); they
 * are kept under their generated names.
 */
public class CausalSelfAttention {
    private final AbstractModel m;
    private final Config c;
    private final int layerIndex;
    private final DistributedContext dctx;
    private final Optional<AbstractTensor> queryAttnBias;
    private final Optional<AbstractTensor> keyAttnBias;
    private final Optional<AbstractTensor> valueAttnBias;
    private final Optional<AbstractTensor> outputProjectionBias;
    final AbstractTensor queryAttnWeights;
    final AbstractTensor keyAttnWeights;
    final AbstractTensor valueAttnWeights;
    private final AbstractTensor outputProjectionWeights;
    /** {@code 1 / sqrt(headSize)}, applied to the attention scores before softmax. */
    private final float attentionScale;
    /** {@code numberOfHeads * headSize}. */
    private final int attentionLength;
    /** Scratch slots for batched q/k/v projection results (filled by the forward pass). */
    private final AbstractTensor[] qkvResults;
    /** The query, key and value weights, in that order. */
    private final AbstractTensor[] qkvWeights;

    /**
     * Creates an attention block without any bias tensors.
     *
     * @param m owning model; its config is read during construction
     * @param layerIndex index of the layer this block belongs to
     * @param queryAttnWeights query projection weights
     * @param keyAttnWeights key projection weights
     * @param valueAttnWeights value projection weights
     * @param outputProjectionWeights output projection weights
     */
    public CausalSelfAttention(AbstractModel m, int layerIndex, AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights, AbstractTensor outputProjectionWeights) {
        this(
            m,
            layerIndex,
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            queryAttnWeights,
            keyAttnWeights,
            valueAttnWeights,
            Optional.empty(),
            outputProjectionWeights);
    }

    /**
     * Creates an attention block where every bias tensor is present.
     *
     * <p>The biases are wrapped with {@link Optional#of(Object)}, so each one must be non-null.
     * Use the {@code Optional}-based constructor when only some biases exist.
     *
     * @throws NullPointerException if any bias argument is {@code null}
     */
    public CausalSelfAttention(AbstractModel m, int layerIndex, AbstractTensor queryAttnBias, AbstractTensor keyAttnBias, AbstractTensor valueAttnBias, AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights, AbstractTensor outputProjectionBias, AbstractTensor outputProjectionWeights) {
        this(
            m,
            layerIndex,
            Optional.of(queryAttnBias),
            Optional.of(keyAttnBias),
            Optional.of(valueAttnBias),
            queryAttnWeights,
            keyAttnWeights,
            valueAttnWeights,
            Optional.of(outputProjectionBias),
            outputProjectionWeights);
    }

    /**
     * Primary constructor: stores all tensors and derives the config-dependent constants.
     *
     * @param m owning model; {@code m.c} supplies the {@link Config} and the distributed context
     * @param layerIndex index of the layer this block belongs to
     * @param queryAttnBias optional bias for the query projection
     * @param keyAttnBias optional bias for the key projection
     * @param valueAttnBias optional bias for the value projection
     * @param queryAttnWeights query projection weights
     * @param keyAttnWeights key projection weights
     * @param valueAttnWeights value projection weights
     * @param outputProjectionBias optional bias for the output projection
     * @param outputProjectionWeights output projection weights
     */
    public CausalSelfAttention(AbstractModel m, int layerIndex, Optional<AbstractTensor> queryAttnBias, Optional<AbstractTensor> keyAttnBias, Optional<AbstractTensor> valueAttnBias, AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights, Optional<AbstractTensor> outputProjectionBias, AbstractTensor outputProjectionWeights) {
        this.m = m;
        this.layerIndex = layerIndex;
        this.c = m.c;
        this.dctx = m.c.dctx();

        this.queryAttnBias = queryAttnBias;
        this.keyAttnBias = keyAttnBias;
        this.valueAttnBias = valueAttnBias;
        this.queryAttnWeights = queryAttnWeights;
        this.keyAttnWeights = keyAttnWeights;
        this.valueAttnWeights = valueAttnWeights;
        this.outputProjectionBias = outputProjectionBias;
        this.outputProjectionWeights = outputProjectionWeights;

        this.attentionLength = this.c.numberOfHeads * this.c.headSize;
        this.attentionScale = (float)(1.0 / StrictMath.sqrt(this.c.headSize));

        this.qkvResults = new AbstractTensor[3];
        this.qkvWeights = new AbstractTensor[]{queryAttnWeights, keyAttnWeights, valueAttnWeights};
    }

    /**
     * Forward pass of the attention block.
     *
     * <p>The original body could not be recovered: the decompiler (CFR 0.152) aborted on this
     * method with {@code ConfusedCFRException: Started 3 blocks at once}. Until the body is
     * restored from the original sources, every call fails.
     *
     * @param input input tensor
     * @param startPosition starting position for this call
     * @param kvMem key/value buffer
     * @param tensorReducer optional callback receiving a list of tensors
     * @return never returns normally in this decompiled form
     * @throws IllegalStateException always
     */
    public AbstractTensor forward(AbstractTensor input, int startPosition, KvBufferCache.KvBuffer kvMem, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        // Placeholder emitted by the decompiler; the real implementation is missing.
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Lambda body: accumulates {@code bias} into {@code result} over the range
     * {@code [0, embeddingLength)} (offset/length semantics as defined by {@code accumulate}).
     */
    private /* synthetic */ void lambda$forward$10(AbstractTensor result, AbstractTensor bias) {
        TensorOperationsProvider.get().accumulate(result, bias, 0, this.c.embeddingLength);
    }

    /** Lambda body: hands {@code result} to the reducer callback as a single-element list. */
    private static /* synthetic */ void lambda$forward$9(AbstractTensor result, Consumer<List<AbstractTensor>> func) {
        // Parameterized (was a raw Consumer) so the accept call is type-checked; erasure is unchanged.
        func.accept(Collections.singletonList(result));
    }

    /**
     * Lambda body: one chunk of the output projection, multiplying {@code vq} with the output
     * projection weights restricted to this node's attention segment.
     */
    private /* synthetic */ void lambda$forward$8(AbstractTensor result, AbstractTensor vq, int chunkStart, int chunkSize) {
        TensorOperationsProvider.get().dotProductChunk(
            result,
            vq,
            this.outputProjectionWeights,
            this.dctx.attentionSegmentStart,
            this.dctx.attentionSegmentLength,
            chunkStart,
            chunkSize);
    }

    /**
     * Lambda body: attention for a single head {@code h}.
     *
     * <p>Scores the query slice of head {@code h} against every key page, scales the scores by
     * {@code attentionScale}, applies softmax over positions {@code [0, finalPosition]}, and then
     * accumulates the weighted value pages into {@code value}.
     *
     * @param query query tensor; the head's slice starts at {@code h * headSize}
     * @param kvp key pages; presumably all pages share the same first dimension, since offsets
     *     are computed as {@code pageIndex * pageLength} — TODO confirm against the caller
     * @param finalPosition last position (inclusive) that takes part in attention
     * @param vvp value pages, laid out like {@code kvp}
     * @param value output tensor receiving the weighted sum
     * @param h head index
     */
    private /* synthetic */ void lambda$forward$7(AbstractTensor query, AbstractTensor[] kvp, int finalPosition, AbstractTensor[] vvp, AbstractTensor value, int h) {
        int groupHeadOffset = this.c.maybeMapToGroupHead(h) * this.c.headSize;
        int headOffset = h * this.c.headSize;

        // Heads that start beyond the end of the query tensor are skipped.
        if (headOffset >= query.shape().last()) {
            return;
        }

        // One score slot per position across all key pages.
        try (AbstractTensor attn = this.m.makeDenseTensor(1, kvp[0].shape().first() * kvp.length)) {
            for (int page = 0; page < kvp.length; ++page) {
                int pageLength = kvp[page].shape().first();
                int pageStart = page * pageLength;
                int rows = rowsInPage(page, kvp.length, pageLength, finalPosition);
                TensorOperationsProvider.get().batchDotProduct(attn, query, kvp[page], headOffset, groupHeadOffset, this.c.headSize, pageStart, 0, rows);
            }

            TensorOperationsProvider.get().scale(this.attentionScale, attn, 0, finalPosition + 1);
            VectorMath.softMax(attn, 0, finalPosition + 1);

            for (int page = 0; page < vvp.length; ++page) {
                int pageLength = vvp[page].shape().first();
                int pageStart = page * pageLength;
                int rows = rowsInPage(page, vvp.length, pageLength, finalPosition);
                TensorOperationsProvider.get().saxpy(attn, vvp[page], value, groupHeadOffset, headOffset, this.c.headSize, pageStart, 0, rows);
            }
        }
    }

    /**
     * Number of rows of a page that take part in attention: the full page length for every page
     * except the last, which only contributes rows up to and including {@code finalPosition}.
     */
    private static int rowsInPage(int pageIndex, int pageCount, int pageLength, int finalPosition) {
        boolean lastPage = pageIndex == pageCount - 1;
        return lastPage ? finalPosition + 1 - pageIndex * pageLength : pageLength;
    }

    /**
     * Lambda body: rotates query and key in place using coefficient pairs from {@code rf}
     * (the debug labels call the result "rope").
     *
     * <p>Within each head, element {@code i} is paired with element {@code i + headSize / 2}.
     * With grouped-query attention ({@code c.isGQA}) query heads and key heads are processed in
     * separate loops and query heads look up coefficients through their mapped group head;
     * otherwise query and key share the same head range and coefficient index.
     *
     * @param finalPosition position whose coefficients are used; rows of {@code rf} are indexed
     *     from {@code finalPosition * (headSize / 2)}
     * @param query query tensor, modified in place
     * @param key key tensor, modified in place
     * @param rf coefficient table; each row is read as {@code {f[0], f[1]}}
     */
    private /* synthetic */ void lambda$forward$6(int finalPosition, AbstractTensor query, AbstractTensor key, float[][] rf) {
        int headPiece = this.c.headSize / 2;
        int poffset = finalPosition * headPiece;

        if (this.c.isGQA) {
            // Query heads: coefficient index follows the mapped group head.
            for (int h = this.dctx.headStart; h < this.dctx.headEnd; ++h) {
                int offset = h * this.c.headSize;
                if (offset >= query.shape().last()) {
                    break;
                }
                int g = this.c.maybeMapToGroupHead(h) * this.c.headSize;
                for (int i = offset; i < offset + headPiece; ++i, ++g) {
                    rotatePair(query, i, headPiece, rf, poffset + g);
                }
            }

            // Key heads: iterate over this node's group heads only.
            for (int h = this.dctx.groupHeadStart; h < this.dctx.groupHeadEnd; ++h) {
                int offset = h * this.c.headSize;
                if (offset >= key.shape().last()) {
                    break;
                }
                for (int i = offset; i < offset + headPiece; ++i) {
                    rotatePair(key, i, headPiece, rf, poffset + i);
                }
            }
        } else {
            // Kept inline on purpose: all reads of query and key happen before any write,
            // exactly as in the recovered code.
            for (int h = this.dctx.headStart; h < this.dctx.headEnd; ++h) {
                int offset = h * this.c.headSize;
                for (int i = offset; i < offset + headPiece; ++i) {
                    float q0 = query.get(0, i);
                    float q1 = query.get(0, i + headPiece);
                    float k0 = key.get(0, i);
                    float k1 = key.get(0, i + headPiece);
                    float[] f = rf[poffset + i];
                    float fcr = f[0];
                    float fci = f[1];
                    query.set(q0 * fcr - q1 * fci, 0, i);
                    query.set(q0 * fci + q1 * fcr, 0, i + headPiece);
                    key.set(k0 * fcr - k1 * fci, 0, i);
                    key.set(k0 * fci + k1 * fcr, 0, i + headPiece);
                }
            }
        }

        DebugSupport.debug("query+rope", query, finalPosition);
        DebugSupport.debug("key+rope", key, finalPosition);
    }

    /**
     * Rotates the pair {@code (t[0, i], t[0, i + headPiece])} in place by the coefficients in
     * {@code rf[freqIndex]}: both elements are read, then the coefficients, then both are written.
     */
    private static void rotatePair(AbstractTensor t, int i, int headPiece, float[][] rf, int freqIndex) {
        float lo = t.get(0, i);
        float hi = t.get(0, i + headPiece);
        float[] f = rf[freqIndex];
        float fcr = f[0];
        float fci = f[1];
        t.set(lo * fcr - hi * fci, 0, i);
        t.set(lo * fci + hi * fcr, 0, i + headPiece);
    }

    /** Lambda body: accumulates the value bias into {@code tmpValBatch} over this node's kv segment. */
    private /* synthetic */ void lambda$forward$5(AbstractTensor tmpValBatch, AbstractTensor bias) {
        TensorOperationsProvider.get().accumulate(
            tmpValBatch,
            bias,
            this.dctx.kvSegmentStart,
            this.dctx.kvSegmentLength);
    }

    /** Lambda body: accumulates the key bias into {@code tmpKeyBatch} over this node's kv segment. */
    private /* synthetic */ void lambda$forward$4(AbstractTensor tmpKeyBatch, AbstractTensor bias) {
        TensorOperationsProvider.get().accumulate(
            tmpKeyBatch,
            bias,
            this.dctx.kvSegmentStart,
            this.dctx.kvSegmentLength);
    }

    /** Lambda body: accumulates the query bias into {@code queryBatch} over this node's attention segment. */
    private /* synthetic */ void lambda$forward$3(AbstractTensor queryBatch, AbstractTensor bias) {
        TensorOperationsProvider.get().accumulate(
            queryBatch,
            bias,
            this.dctx.attentionSegmentStart,
            this.dctx.attentionSegmentLength);
    }

    /**
     * Lambda body: one chunk of the batched q/k/v projection, writing into {@code qkvResults}
     * using {@code qkvWeights} over the range {@code [0, embeddingLength)}.
     */
    private /* synthetic */ void lambda$forward$2(AbstractTensor input, int chunkStart, int chunkLength) {
        TensorOperationsProvider.get().dotProductBatchChunk(
            this.qkvResults,
            input,
            this.qkvWeights,
            0,
            this.c.embeddingLength,
            chunkStart,
            chunkLength);
    }

    /** Lambda body: one chunk of the key projection followed by the same chunk of the value projection. */
    private /* synthetic */ void lambda$forward$1(AbstractTensor tmpKeyBatch, AbstractTensor input, AbstractTensor tmpValBatch, int chunkStart, int chunkLength) {
        TensorOperationsProvider.get().dotProductChunk(tmpKeyBatch, input, this.keyAttnWeights, 0, this.c.embeddingLength, chunkStart, chunkLength);
        TensorOperationsProvider.get().dotProductChunk(tmpValBatch, input, this.valueAttnWeights, 0, this.c.embeddingLength, chunkStart, chunkLength);
    }

    /** Lambda body: one chunk of the query projection. */
    private /* synthetic */ void lambda$forward$0(AbstractTensor queryBatch, AbstractTensor input, int chunkStart, int chunkLength) {
        TensorOperationsProvider.get().dotProductChunk(queryBatch, input, this.queryAttnWeights, 0, this.c.embeddingLength, chunkStart, chunkLength);
    }
}

