/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.TensorInfo;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.Weights;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SafeTensorIndex
implements WeightLoader,
AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(SafeTensorIndex.class);
    private static final ObjectMapper om = new ObjectMapper();
    public static final String SINGLE_MODEL_NAME = "model.safetensors";
    public static final String MODEL_INDEX_JSON = "model.safetensors.index.json";
    private final Map<String, String> metadata;
    final Map<String, String> weightFileMap;
    private final Map<String, Weights> weightMap = new HashMap<String, Weights>();
    private final Map<String, RandomAccessFile> fileMap = new HashMap<String, RandomAccessFile>();

    public static SafeTensorIndex loadWithWeights(Path modelRoot) {
        try {
            File indexFile = Paths.get(modelRoot.toString(), MODEL_INDEX_JSON).toFile();
            SafeTensorIndex index = (SafeTensorIndex)om.readValue(indexFile, SafeTensorIndex.class);
            SafeTensorIndex.loadWeights(index, modelRoot);
            return index;
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public static SafeTensorIndex loadSingleFile(Path modelRoot, String modelFile) {
        try {
            SafeTensorIndex index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", modelFile));
            SafeTensorIndex.loadWeights(index, modelRoot);
            return index;
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    static void loadWeights(SafeTensorIndex index, Path modelRoot) throws IOException {
        for (Map.Entry<String, String> e : index.weightFileMap.entrySet()) {
            if (index.fileMap.containsKey(e.getValue())) continue;
            RandomAccessFile raf = new RandomAccessFile(Paths.get(modelRoot.toString(), e.getValue()).toFile(), "r");
            index.fileMap.put(e.getValue(), raf);
            MappedByteBuffer header = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, 0L, Math.min(0x100000L, raf.length()));
            HashMap<String, String> metadata = new HashMap<String, String>();
            Map<String, TensorInfo> tensorInfoMap = SafeTensorSupport.readTensorInfoMap(header, Optional.of(metadata));
            int endOfHeaderPosition = header.position();
            Map<List<Long>, List<String>> splits = index.computeMmapSplits(tensorInfoMap, raf.length());
            for (Map.Entry<List<Long>, List<String>> split : splits.entrySet()) {
                long offset = split.getKey().get(0);
                long length = split.getKey().get(1);
                List<String> tensors = split.getValue();
                MappedByteBuffer buf = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, (long)endOfHeaderPosition + offset, length - offset);
                Map mmapTensorInfoMap = (Map)tensorInfoMap.entrySet().stream().filter(x -> tensors.contains(x.getKey())).collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
                Weights mmapWeights = new Weights(metadata, mmapTensorInfoMap, buf, Optional.of(index));
                for (String tensor : tensors) {
                    index.weightMap.put(tensor, mmapWeights);
                }
            }
        }
    }

    private Map<List<Long>, List<String>> computeMmapSplits(Map<String, TensorInfo> tensorInfoMap, long fileLength) {
        HashSet<String> added = new HashSet<String>();
        HashMap<List<Long>, List<String>> splits = new HashMap<List<Long>, List<String>>();
        long lastSplitOffset = 0L;
        while (added.size() < tensorInfoMap.size()) {
            ArrayList<String> tensors = new ArrayList<String>();
            long limit = lastSplitOffset + Integer.MAX_VALUE;
            long startOffset = fileLength;
            long endOffset = 0L;
            for (Map.Entry<String, TensorInfo> e : tensorInfoMap.entrySet()) {
                if (added.contains(e.getKey())) continue;
                TensorInfo info = e.getValue();
                if (info.dataOffsets[1] >= limit) continue;
                tensors.add(e.getKey());
                added.add(e.getKey());
                if (info.dataOffsets[1] > endOffset) {
                    endOffset = info.dataOffsets[1];
                }
                if (info.dataOffsets[0] < startOffset) {
                    startOffset = info.dataOffsets[0];
                }
                info.dataOffsets[0] = info.dataOffsets[0] - lastSplitOffset;
                info.dataOffsets[1] = info.dataOffsets[1] - lastSplitOffset;
                logger.debug("Adding tensor {} to split {}-{}", new Object[]{e.getKey(), info.dataOffsets[0], info.dataOffsets[1]});
            }
            logger.debug("Adding split {}-{} with {} tensors", new Object[]{startOffset, endOffset, tensors.size()});
            assert (endOffset - startOffset < Integer.MAX_VALUE) : "Mmap split too large " + (endOffset - startOffset) + " > 2147483647 " + lastSplitOffset;
            splits.put(List.of(Long.valueOf(startOffset), Long.valueOf(endOffset)), tensors);
            lastSplitOffset = endOffset;
        }
        return splits;
    }

    @JsonCreator
    SafeTensorIndex(@JsonProperty(value="metadata") Map<String, String> metadata, @JsonProperty(value="weight_map") Map<String, String> weightFileMap) {
        this.metadata = ImmutableMap.copyOf(metadata);
        this.weightFileMap = ImmutableMap.copyOf(weightFileMap);
    }

    @Override
    public Map<String, String> metadata() {
        return this.metadata;
    }

    @Override
    public Map<String, TensorInfo> tensorInfoMap() {
        HashMap<String, TensorInfo> tensorInfoMap = new HashMap<String, TensorInfo>();
        for (String name : this.weightMap.keySet()) {
            Weights w = this.weightMap.get(name);
            if (w == null) {
                throw new NoSuchElementException(name);
            }
            tensorInfoMap.put(name, w.tensorInfoMap().get(name));
        }
        return tensorInfoMap;
    }

    @Override
    public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) {
        Weights w = this.weightMap.get(name);
        if (w == null) {
            throw new NoSuchElementException(name);
        }
        return w.load(name, dctx, sparseRows, sparseColumns);
    }

    @Override
    public DType getModelDType() {
        return this.weightMap.values().iterator().next().getModelDType();
    }

    @Override
    public void close() throws Exception {
        this.weightMap.clear();
        this.fileMap.forEach((k, v) -> {
            try {
                v.close();
            }
            catch (IOException iOException) {
                // empty catch block
            }
        });
        this.fileMap.clear();
    }
}

