/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors;

import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.tensor.TensorCache;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.io.Files;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Model hyper-parameters plus a small amount of mutable runtime state (the distributed context and
 * an optional working directory).
 *
 * <p>Derived sizes ({@link #attentionLength}, {@link #kvLength}, {@link #headGroupSize},
 * {@link #isGQA}) and the optional RoPE frequency table are computed once in the constructor.
 */
public class Config {
    /** Context length; also used as the position count when precomputing RoPE frequencies. */
    public final int contextLength;
    /** Embedding dimension. */
    public final int embeddingLength;
    /** Total width of all query heads: {@code numberOfHeads * headSize}. */
    public final int attentionLength;
    /** Hidden dimension as supplied by the caller. */
    public final int hiddenLength;
    /** Number of query attention heads. */
    public final int numberOfHeads;
    /** Number of key/value heads; fewer than {@link #numberOfHeads} means grouped-query attention. */
    public final int numberOfKeyValueHeads;
    /** Per-head dimension. */
    public final int headSize;
    /** Activation function type used by the model. */
    public final ActivationFunction.Type activationFunction;
    /** Query heads per key/value head: {@code numberOfHeads / numberOfKeyValueHeads}. */
    public final int headGroupSize;
    /** Total width of all key/value heads: {@code numberOfKeyValueHeads * headSize}. */
    public final int kvLength;
    /** True when there are fewer key/value heads than query heads. */
    public final boolean isGQA;
    /** Number of layers. */
    public final int numberOfLayers;
    /** Epsilon used by layer normalization. */
    public final float layerNormEps;
    /** Vocabulary size. */
    public final int vocabularySize;
    /** Beginning-of-sequence token id. */
    public final int bosToken;
    /** End-of-sequence token ids (stored as passed in, not copied). */
    public final List<Integer> eosTokens;
    /** RoPE frequency table from {@code VectorMath.precomputeFreqsCis}; empty when no theta is given. */
    public final Optional<float[][]> ropeFreqs;
    /** Immutable copy of the classification labels; empty when the model is not a classifier. */
    public final Optional<BiMap<String, Integer>> classifcationLabels;
    private volatile DistributedContext dctx;
    private volatile File workingDirectory;
    /** Reference to {@code TensorCache.instance}. */
    public final TensorCache tensorCache;

    /**
     * Creates a config with no classification labels and a head size of
     * {@code embeddingLength / numberOfHeads}.
     */
    public Config(int contextLength, int embeddingLength, int hiddenLength, int numberOfHeads, int numberOfKeyValueHeads, int numberOfLayers, float layerNormEps, int vocabularySize, int bosToken, List<Integer> eosToken, ActivationFunction.Type activationFunction, Double ropeFreqsTheta, Double ropeScalingFactor) {
        this(contextLength, embeddingLength, hiddenLength, numberOfHeads, numberOfKeyValueHeads, numberOfLayers, layerNormEps, vocabularySize, bosToken, eosToken, activationFunction, ropeFreqsTheta, ropeScalingFactor, null, embeddingLength / numberOfHeads);
    }

    /**
     * Creates a config with optional classification labels and a head size of
     * {@code embeddingLength / numberOfHeads}.
     */
    public Config(int contextLength, int embeddingLength, int hiddenLength, int numberOfHeads, int numberOfKeyValueHeads, int numberOfLayers, float layerNormEps, int vocabularySize, int bosToken, List<Integer> eosToken, ActivationFunction.Type activationFunction, Double ropeFreqsTheta, Double ropeScalingFactor, Map<String, Integer> classifcationLabels) {
        this(contextLength, embeddingLength, hiddenLength, numberOfHeads, numberOfKeyValueHeads, numberOfLayers, layerNormEps, vocabularySize, bosToken, eosToken, activationFunction, ropeFreqsTheta, ropeScalingFactor, classifcationLabels, embeddingLength / numberOfHeads);
    }

    /**
     * Creates a fully specified config.
     *
     * @param ropeFreqsTheta RoPE theta; when {@code null} no frequency table is precomputed
     * @param ropeScalingFactor RoPE scaling factor; {@code null} is treated as {@code 1.0}
     * @param classifcationLabels label-to-id map, or {@code null} for non-classifier models
     * @param headSize per-head dimension; {@code null} falls back to
     *     {@code embeddingLength / numberOfHeads}
     * @throws IllegalArgumentException if {@code numberOfKeyValueHeads} is not positive
     */
    public Config(int contextLength, int embeddingLength, int hiddenLength, int numberOfHeads, int numberOfKeyValueHeads, int numberOfLayers, float layerNormEps, int vocabularySize, int bosToken, List<Integer> eosTokens, ActivationFunction.Type activationFunction, Double ropeFreqsTheta, Double ropeScalingFactor, Map<String, Integer> classifcationLabels, Integer headSize) {
        // Guard the divisor used for headGroupSize below.
        Preconditions.checkArgument(numberOfKeyValueHeads > 0, "numberOfKeyValueHeads must be positive, got %s", numberOfKeyValueHeads);
        // Unboxing a null headSize would throw an NPE; use the same default the shorter constructors pass.
        int resolvedHeadSize = headSize != null ? headSize : embeddingLength / numberOfHeads;

        this.contextLength = contextLength;
        this.attentionLength = numberOfHeads * resolvedHeadSize;
        this.embeddingLength = embeddingLength;
        this.hiddenLength = hiddenLength;
        this.numberOfHeads = numberOfHeads;
        this.numberOfKeyValueHeads = numberOfKeyValueHeads;
        this.numberOfLayers = numberOfLayers;
        this.layerNormEps = layerNormEps;
        this.vocabularySize = vocabularySize;
        this.bosToken = bosToken;
        this.eosTokens = eosTokens;
        this.tensorCache = TensorCache.instance;
        this.headSize = resolvedHeadSize;
        this.headGroupSize = numberOfHeads / numberOfKeyValueHeads;
        this.kvLength = numberOfKeyValueHeads * resolvedHeadSize;
        this.isGQA = numberOfKeyValueHeads < numberOfHeads;
        this.activationFunction = activationFunction;
        this.ropeFreqs = ropeFreqsTheta == null
                ? Optional.empty()
                : Optional.of(VectorMath.precomputeFreqsCis(resolvedHeadSize, contextLength, ropeFreqsTheta, ropeScalingFactor == null ? 1.0 : ropeScalingFactor));
        this.classifcationLabels = classifcationLabels == null ? Optional.empty() : Optional.of(ImmutableBiMap.copyOf(classifcationLabels));
        // Built last so the builder sees every other field already assigned.
        this.dctx = DistributedContext.builder(this).build();
    }

    /** Replaces the distributed context built by the constructor. */
    public void setDistributedContext(DistributedContext dctx) {
        this.dctx = dctx;
    }

    /**
     * Sets the working directory.
     *
     * @param workingDirectory an existing directory, or {@code null} to create a fresh temporary
     *     directory that is registered for deletion on JVM exit
     * @throws IllegalArgumentException if a non-null argument is not a directory
     * @throws IllegalStateException if a temporary directory cannot be created
     */
    public void setWorkingDirectory(File workingDirectory) {
        if (workingDirectory == null) {
            File tempDir;
            try {
                tempDir = java.nio.file.Files.createTempDirectory("jlama").toFile();
            } catch (IOException e) {
                // IllegalStateException matches what Guava's Files.createTempDir threw on failure.
                throw new IllegalStateException("Failed to create temporary working directory", e);
            }
            // NOTE: deleteOnExit can only remove the directory if it is empty at shutdown.
            tempDir.deleteOnExit();
            this.workingDirectory = tempDir;
        } else {
            Preconditions.checkArgument(workingDirectory.isDirectory(), "Working directory is not a directory: %s", workingDirectory);
            this.workingDirectory = workingDirectory;
        }
    }

    /** Returns the working directory, or empty if none has been set. */
    public Optional<File> workingDirectory() {
        return Optional.ofNullable(this.workingDirectory);
    }

    /** Returns the current distributed context. */
    public DistributedContext dctx() {
        return this.dctx;
    }

    /**
     * Maps a query-head index to its key/value group index.
     *
     * @param head query-head index
     * @return {@code head / headGroupSize} (floor division) under grouped-query attention,
     *     otherwise {@code head} unchanged
     */
    public int maybeMapToGroupHead(int head) {
        if (!this.isGQA) {
            return head;
        }
        return Math.floorDiv(head, this.headGroupSize);
    }

    /** Returns true when classification labels were supplied. */
    public boolean isClassifier() {
        return this.classifcationLabels.isPresent();
    }
}

