/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model.llama;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.MLPBlock;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.RMSNorm;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.util.Optional;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LlamaModel
extends AbstractModel {
    private static final Logger logger = LoggerFactory.getLogger(LlamaModel.class);

    public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
        super(AbstractModel.InferenceType.FULL_GENERATION, config, weights, tokenizer, workingDType, workingQType, modelQType);
    }

    public LlamaModel(AbstractModel.InferenceType inferenceType, Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
        super(inferenceType, config, weights, tokenizer, workingDType, workingQType, modelQType);
    }

    @Override
    public ModelSupport.ModelType getModelType() {
        return ModelSupport.ModelType.LLAMA;
    }

    @Override
    protected EmbedInput loadInputWeights() {
        AbstractTensor wte = this.weights.load("model.embed_tokens.weight").quantize(this.workingDType);
        return (inputToken, position) -> {
            AbstractTensor embedding = this.makeDenseTensor(1, this.c.embeddingLength);
            AbstractTensor at = wte.slice(true, inputToken);
            if (wte.dType() != embedding.dType()) {
                at = TensorOperationsProvider.get().quantize(at, embedding.dType(), 0, this.c.embeddingLength);
            }
            embedding.copyFrom(at, 0, 0, this.c.embeddingLength);
            return embedding;
        };
    }

    @Override
    protected TransformerBlock[] loadTransformerBlockWeights() {
        DType qType = this.modelQType.orElse(this.modelDType);
        if (qType != this.modelDType) {
            logger.info("Quantizing model with {} - Please hold...", (Object)qType);
        }
        TransformerBlock[] transformerBlocks = new TransformerBlock[this.c.dctx().numberOfLayers];
        IntStream.range(this.c.dctx().layerStart, this.c.dctx().layerEnd).parallel().forEach(i -> {
            String base = "model.layers." + i + ".";
            String prefix = base + "self_attn.";
            CausalSelfAttention attention = new CausalSelfAttention(this, i, this.weights.load(prefix + "q_proj.weight", this.c.dctx(), true, false).quantize(qType), this.weights.load(prefix + "k_proj.weight", this.c.dctx(), true, false).quantize(qType), this.weights.load(prefix + "v_proj.weight", this.c.dctx(), true, false).quantize(qType), this.weights.load(prefix + "o_proj.weight", this.c.dctx(), false, true).quantize(qType));
            prefix = base + "mlp.";
            MLPBlock mlp = new MLPBlock(this, this.c.activationFunction, this.weights.load(prefix + "gate_proj.weight", this.c.dctx(), true, false).quantize(qType), this.weights.load(prefix + "down_proj.weight", this.c.dctx(), false, true).quantize(qType), this.weights.load(prefix + "up_proj.weight", this.c.dctx(), true, false).quantize(qType));
            transformerBlocks[i] = new TransformerBlock((AbstractModel)this, i, new RMSNorm(this, this.weights.load(base + "input_layernorm.weight").quantize(qType)), attention, new RMSNorm(this, this.weights.load(base + "post_attention_layernorm.weight").quantize(qType)), mlp);
        });
        return transformerBlocks;
    }

    @Override
    protected SampleOutput loadOutputWeights() {
        DType qType = this.modelQType.orElse(this.modelDType);
        final RMSNorm outputLayerNorm = new RMSNorm(this, this.weights.load("model.norm.weight").quantize(qType));
        final AbstractTensor classificationWeights = this.weights.load("lm_head.weight").quantize(this.workingDType);
        return new SampleOutput(){

            @Override
            public LayerNorm getOutputLayerNorm() {
                return outputLayerNorm;
            }

            @Override
            public AbstractTensor getOutputLogitsWeights() {
                return classificationWeights;
            }
        };
    }

    @Override
    protected AbstractTensor maybeQuantize(AbstractTensor t) {
        Preconditions.checkArgument((t.dims() == 2 ? 1 : 0) != 0, (Object)"Unexpected shape");
        if (t.dType() == this.workingQType) {
            return super.maybeQuantize(t);
        }
        return TensorOperationsProvider.get().quantize(t, this.workingQType, 0, Ints.checkedCast((long)t.shape().last()));
    }
}

