/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.functions.FeedForward;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * One transformer layer: a self-attention sub-block and a feed-forward sub-block, each followed by a
 * residual accumulation, with layer norms placed according to the constructor that was used.
 *
 * <p>Two layouts are supported:
 * <ul>
 *   <li>{@code preAttentionNorm} applied to the layer input, no output norm (first constructor);</li>
 *   <li>no input norm, {@code postFFNorm} applied to the layer output (second constructor).</li>
 * </ul>
 * In both layouts {@code postAttentionNorm} is applied to the attention residual before the
 * feed-forward block.
 */
public class TransformerBlock {
    private static final Logger logger = LoggerFactory.getLogger(TransformerBlock.class);
    private final AbstractModel model;
    final int layerIndex;
    final Optional<LayerNorm> preAttentionNorm;
    final CausalSelfAttention attention;
    final LayerNorm postAttentionNorm;
    final FeedForward ffBlock;
    final Optional<LayerNorm> postFFNorm;

    /**
     * Creates a layer that normalizes its input before attention and applies no norm to its output.
     *
     * @param model owning model; used for {@code maybeQuantize} and the embedding length
     * @param layerIndex index of this layer, used to tag debug output
     * @param preAttentionNorm norm applied to the input embedding before attention
     * @param attention self-attention sub-block
     * @param postAttentionNorm norm applied to the attention residual before the feed-forward block
     * @param ffBlock feed-forward sub-block
     * @throws NullPointerException if {@code preAttentionNorm} is null (via {@link Optional#of})
     */
    public TransformerBlock(AbstractModel model, int layerIndex, LayerNorm preAttentionNorm, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock) {
        this.model = model;
        this.layerIndex = layerIndex;
        this.preAttentionNorm = Optional.of(preAttentionNorm);
        this.attention = attention;
        this.postAttentionNorm = postAttentionNorm;
        this.ffBlock = ffBlock;
        this.postFFNorm = Optional.empty();
    }

    /**
     * Creates a layer that feeds its raw input to attention and normalizes the feed-forward output.
     *
     * @param model owning model; used for {@code maybeQuantize} and the embedding length
     * @param layerIndex index of this layer, used to tag debug output
     * @param attention self-attention sub-block
     * @param postAttentionNorm norm applied to the attention residual before the feed-forward block
     * @param ffBlock feed-forward sub-block
     * @param postFFNorm norm applied to the feed-forward residual to produce the layer output
     * @throws NullPointerException if {@code postFFNorm} is null (via {@link Optional#of})
     */
    public TransformerBlock(AbstractModel model, int layerIndex, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock, LayerNorm postFFNorm) {
        this.model = model;
        this.layerIndex = layerIndex;
        this.preAttentionNorm = Optional.empty();
        this.attention = attention;
        this.postAttentionNorm = postAttentionNorm;
        this.ffBlock = ffBlock;
        this.postFFNorm = Optional.of(postFFNorm);
    }

    /**
     * Runs this layer without a tensor reducer.
     *
     * @see #forward(AbstractTensor, int, KvBufferCache.KvBuffer, Optional)
     */
    public AbstractTensor forward(AbstractTensor embedding, int position, KvBufferCache.KvBuffer kvBuffer) {
        return this.forward(embedding, position, kvBuffer, Optional.empty());
    }

    /**
     * Runs this layer on {@code embedding}.
     *
     * <p>Every intermediate tensor created here is closed before returning, on both normal and
     * exceptional exit. The input {@code embedding} itself is never closed by this method.
     *
     * @param embedding layer input; read but not closed
     * @param position passed through to the attention sub-block
     * @param kvBuffer passed through to the attention sub-block
     * @param tensorReducer passed through to both the attention and feed-forward sub-blocks
     * @return the layer output: the feed-forward residual, or its normalized form when an output
     *     norm is configured
     */
    public AbstractTensor forward(AbstractTensor embedding, int position, KvBufferCache.KvBuffer kvBuffer, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        DebugSupport.debug("input_emb", embedding, this.layerIndex);
        AbstractTensor postFF = this.attentionAndFeedForward(embedding, position, kvBuffer, tensorReducer);
        return this.postFFNorm.map(ln -> this.applyOutputNorm(ln, postFF)).orElse(postFF);
    }

    /**
     * Runs the two residual sub-blocks: (optional norm) -> attention -> residual -> norm -> FF -> residual.
     * Closes its intermediates on every exit path; returns the feed-forward residual to the caller.
     */
    private AbstractTensor attentionAndFeedForward(AbstractTensor embedding, int position, KvBufferCache.KvBuffer kvBuffer, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        AbstractTensor lnemb = this.preAttentionNorm.map(ln -> ln.forward(embedding)).orElse(embedding);
        AbstractTensor postAttention = null;
        AbstractTensor lnemb2 = null;
        AbstractTensor postFF = null;
        boolean completed = false;
        try {
            DebugSupport.debug("ln_emb", lnemb, this.layerIndex);
            // NOTE(review): the quantized view is closed here, so this assumes maybeQuantize returns a
            // tensor distinct from its argument — confirm against AbstractModel.
            try (AbstractTensor qlnemb = this.model.maybeQuantize(lnemb)) {
                postAttention = this.attention.forward(qlnemb, position, kvBuffer, tensorReducer);
            }
            DebugSupport.debug("post_attn", postAttention, this.layerIndex);
            // Residual connection; postAttention is read afterwards, so accumulate is assumed to write
            // into its first argument.
            TensorOperationsProvider.get().accumulate(postAttention, embedding, 0, this.model.c.embeddingLength);
            DebugSupport.debug("post_attn_res", postAttention, this.layerIndex);

            lnemb2 = this.postAttentionNorm.forward(postAttention);
            DebugSupport.debug("ln_emb2", lnemb2, this.layerIndex);
            try (AbstractTensor qlnemb2 = this.model.maybeQuantize(lnemb2)) {
                postFF = this.ffBlock.forward(qlnemb2, tensorReducer);
                DebugSupport.debug("post_ff", postFF, this.layerIndex);
            }
            TensorOperationsProvider.get().accumulate(postFF, postAttention, 0, this.model.c.embeddingLength);
            DebugSupport.debug("post_ff_res", postFF, this.layerIndex);
            completed = true;
            return postFF;
        } finally {
            // Same close order as the success path always used; the extra null/flag checks make the
            // cleanup also run correctly when a sub-block throws part-way through.
            if (lnemb != embedding) {
                lnemb.close();
            }
            if (lnemb2 != null) {
                lnemb2.close();
            }
            if (postAttention != null) {
                postAttention.close();
            }
            if (!completed && postFF != null) {
                // The caller never receives postFF on failure, so release it here.
                postFF.close();
            }
        }
    }

    /**
     * Applies the output norm to {@code postFF} and returns the normalized tensor.
     * {@code postFF} is always closed, whether or not the norm succeeds.
     */
    private AbstractTensor applyOutputNorm(LayerNorm ln, AbstractTensor postFF) {
        try {
            AbstractTensor lnout = ln.forward(postFF);
            DebugSupport.debug("ln_out", lnout, this.layerIndex);
            return lnout;
        } finally {
            postFF.close();
        }
    }
}

