/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors;

import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.TensorInfo;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.BFloat16BufferTensor;
import com.github.tjake.jlama.tensor.Float16BufferTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.Pair;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.ShortBuffer;
import java.util.EnumMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link WeightLoader} backed by a single {@link ByteBuffer} of tensor data plus a map of
 * tensor descriptors that locate each tensor inside that buffer.
 */
public class Weights implements WeightLoader {
    private static final Logger logger = LoggerFactory.getLogger(Weights.class);
    private final Map<String, String> metadata;
    private final Map<String, TensorInfo> tensorInfoMap;
    private final ByteBuffer bytes;
    private final DType majorityDType;
    private final Optional<WeightLoader> parent;

    /**
     * Creates a view over a block of tensor data.
     *
     * @param metadata      header metadata; copied into an immutable map
     * @param tensorInfoMap tensor name to descriptor; copied into an immutable map
     * @param bytes         backing data; a duplicate is kept so this instance's position/limit
     *                      are independent of the caller's buffer
     * @param parent        loader used to resolve companion {@code ".qb"} tensors of Q4/I8
     *                      tensors; when empty, this instance resolves them itself
     */
    Weights(Map<String, String> metadata, Map<String, TensorInfo> tensorInfoMap, ByteBuffer bytes, Optional<WeightLoader> parent) {
        this.metadata = ImmutableMap.copyOf(metadata);
        this.tensorInfoMap = ImmutableMap.copyOf(tensorInfoMap);
        this.bytes = bytes.duplicate();
        this.majorityDType = Weights.findDType(tensorInfoMap);
        this.parent = parent;
    }

    /**
     * Returns the most common data type among the given tensors.
     *
     * <p>Entries whose name ends in {@code ".qb"} are ignored. On a tie, the type declared
     * first in {@link DType} wins. An F16 majority is reported as F32; {@link #loadTensorFromBuffer}
     * then upcasts F16 tensors to F32 for such models.
     *
     * @param tensorInfoMap tensor name to descriptor
     * @return the majority type, or {@code null} if no eligible tensors are present
     */
    public static DType findDType(Map<String, TensorInfo> tensorInfoMap) {
        EnumMap<DType, Integer> counts = new EnumMap<>(DType.class);
        for (Map.Entry<String, TensorInfo> e : tensorInfoMap.entrySet()) {
            // ".qb" entries are the companion float tensors loaded alongside Q4/I8 data;
            // they must not sway the vote.
            if (e.getKey().endsWith(".qb")) {
                continue;
            }
            counts.merge(e.getValue().dType, 1, Integer::sum);
        }
        int max = 0;
        DType maxType = null;
        // EnumMap iterates in declaration order; strict '>' keeps the earliest type on ties.
        for (Map.Entry<DType, Integer> e : counts.entrySet()) {
            if (e.getValue() > max) {
                max = e.getValue();
                maxType = e.getKey();
            }
        }
        return maxType == DType.F16 ? DType.F32 : maxType;
    }

    @Override
    public Map<String, String> metadata() {
        return this.metadata;
    }

    @Override
    public Map<String, TensorInfo> tensorInfoMap() {
        return this.tensorInfoMap;
    }

    /**
     * Loads the named tensor, optionally restricted to this node's shard.
     *
     * @param name          tensor name as it appears in {@link #tensorInfoMap()}
     * @param dctx          distribution context, or {@code null} to load the whole tensor
     * @param sparseRows    when {@code dctx} is non-null, read only this shard's rows
     * @param sparseColumns when {@code dctx} is non-null and has a model shard, sparsify columns
     * @return the loaded tensor
     * @throws NoSuchElementException if {@code name} is not present
     * @throws RuntimeException       if the tensor has rank &lt; 1, or rank != 2 while {@code dctx} is non-null
     * @throws IllegalArgumentException if the computed byte range does not fit the backing buffer
     */
    @Override
    public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) {
        TensorInfo info = this.tensorInfoMap.get(name);
        if (info == null) {
            throw new NoSuchElementException(name + " not found in weights");
        }
        if (info.shape.length < 1) {
            throw new RuntimeException("Invalid shape dimensions " + info.shape.length + " encountered for " + name);
        }
        if (dctx != null && info.shape.length != 2) {
            throw new RuntimeException("Invalid shape dimensions " + info.shape.length + " encountered for " + name + " with offset");
        }
        Pair<TensorShape, Pair<Long, Long>> offsets = Weights.getLoadOffsets(info, dctx, sparseRows);
        int start = Ints.checkedCast(offsets.right.left);
        int end = Ints.checkedCast(offsets.right.right);
        ByteBuffer b = this.bytes.duplicate().order(ByteOrder.LITTLE_ENDIAN);
        // Set the limit first so an inverted range (start > end) fails loudly instead of
        // being silently clamped to an empty buffer.
        b.limit(end).position(start);
        return Weights.loadTensorFromBuffer(name, info.dType, this.majorityDType, offsets.left, b, sparseRows, sparseColumns, dctx, this.parent.orElse(this));
    }

    /**
     * Computes the shape to load and the absolute byte range {@code [offset, limit)} within the
     * backing buffer.
     *
     * <p>Without row sharding this is the tensor's full extent. With {@code dctx} and
     * {@code sparseRows}, only this shard's rows of the 2-D tensor are selected.
     */
    static Pair<TensorShape, Pair<Long, Long>> getLoadOffsets(TensorInfo info, DistributedContext dctx, boolean sparseRows) {
        long positionOffset = info.dataOffsets[0];
        long positionLimit = info.dataOffsets[1];
        TensorShape shape = TensorShape.of(info.shape);
        if (dctx != null && sparseRows) {
            int rows = info.shape[0];
            // Bytes per row. Computed in long so that offset * rowBytes below cannot
            // overflow int for large tensors.
            long rowBytes = (long) info.shape[1] * info.dType.size();
            if (info.dType == DType.Q4) {
                // Q4 packs two values per byte.
                rowBytes /= 2;
            }
            int shardOffset = dctx.getShardOffsetForLength(rows);
            int shardLength = dctx.getShardLength(rows);
            positionOffset = info.dataOffsets[0] + shardOffset * rowBytes;
            positionLimit = positionOffset + shardLength * rowBytes;
            shape = TensorShape.sparseRow(info.shape, Pair.of(shardOffset, shardLength));
        }
        return Pair.of(shape, Pair.of(positionOffset, positionLimit));
    }

    /**
     * Wraps the bytes in {@code b} as a tensor of the given type.
     *
     * <p>F16 and BF16 data are upcast to F32 when {@code majorityDType} is F32. Q4 and I8
     * tensors also load their companion {@code name + ".qb"} tensor through {@code loader}.
     * When {@code dctx} is non-null, {@code sparseColumns} is set and the context has a model
     * shard, the result is sparsified along its last dimension.
     *
     * @throws IllegalArgumentException if {@code dType} is not supported
     */
    static AbstractTensor loadTensorFromBuffer(String name, DType dType, DType majorityDType, TensorShape shape, ByteBuffer b, boolean sparseRows, boolean sparseColumns, DistributedContext dctx, WeightLoader loader) {
        AbstractTensor t = switch (dType) {
            case F32 -> {
                FloatBuffer fb = b.asFloatBuffer().slice();
                yield new FloatBufferTensor(name, fb, shape, true);
            }
            case F16 -> {
                if (majorityDType == DType.F32) {
                    yield new FloatBufferTensor(name, upcastToF32(b, DType.F16), shape, true);
                }
                ShortBuffer sb = b.asShortBuffer().slice();
                yield new Float16BufferTensor(name, sb, shape, true);
            }
            case BF16 -> {
                if (majorityDType == DType.F32) {
                    yield new FloatBufferTensor(name, upcastToF32(b, DType.BF16), shape, true);
                }
                ShortBuffer sb = b.asShortBuffer().slice();
                yield new BFloat16BufferTensor(name, sb, shape, true);
            }
            case Q4 -> {
                FloatBufferTensor qb = (FloatBufferTensor) loader.load(name + ".qb", dctx, sparseRows, false);
                yield new Q4ByteBufferTensor(name, b.slice(), qb, shape, true);
            }
            case I8 -> {
                FloatBufferTensor qb = (FloatBufferTensor) loader.load(name + ".qb", dctx, sparseRows, false);
                yield new Q8ByteBufferTensor(name, b.slice(), qb, shape, true);
            }
            default -> throw new IllegalArgumentException("Unsupported Tensor type: " + dType.name() + " for " + name);
        };
        return dctx != null && sparseColumns && dctx.hasModelShard() ? t.sparsify(dctx.getShardOffsetForLength(shape.last()), dctx.getShardLength(shape.last())) : t;
    }

    /**
     * Widens the 16-bit floats remaining in {@code src} into a new little-endian F32 buffer.
     *
     * <p>This consumes {@code src}: its position advances past every value read.
     *
     * @param src     little-endian F16 or BF16 values
     * @param srcType {@link DType#F16} or {@link DType#BF16}; selects the decoding
     */
    private static FloatBuffer upcastToF32(ByteBuffer src, DType srcType) {
        int len = src.remaining() / srcType.size();
        int f32Size = DType.F32.size();
        ByteBuffer dst = ByteBuffer.allocate(len * f32Size).order(ByteOrder.LITTLE_ENDIAN);
        for (int i = 0; i < len; i++) {
            short s = src.getShort();
            float v = srcType == DType.BF16 ? FloatConversions.bFloat16ToFloat32(s) : Float.float16ToFloat(s);
            dst.putFloat(i * f32Size, v);
        }
        return dst.asFloatBuffer();
    }

    @Override
    public DType getModelDType() {
        return this.majorityDType;
    }

    @Override
    public String toString() {
        return "SafeTensor{metadata=" + String.valueOf(this.metadata) + ", tensorInfoMap=" + String.valueOf(this.tensorInfoMap) + ", bytes=" + String.valueOf(this.bytes) + "}";
    }

    /** Two instances are equal when their metadata and tensor maps match; the byte buffer is not compared. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        Weights weights = (Weights) o;
        return Objects.equals(this.metadata, weights.metadata) && Objects.equals(this.tensorInfoMap, weights.tensorInfoMap);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.metadata, this.tensorInfoMap);
    }

    /**
     * No-op: this class releases nothing.
     * NOTE(review): the backing buffer's lifetime is presumably managed by whoever supplied it; confirm.
     */
    @Override
    public void close() throws Exception {
    }
}

