/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.BatchSampler;
import ai.djl.training.dataset.BulkDataIterable;
import ai.djl.training.dataset.DataIterable;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomSampler;
import ai.djl.training.dataset.Record;
import ai.djl.training.dataset.Sampler;
import ai.djl.training.dataset.SequenceSampler;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import ai.djl.util.RandomUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A {@link Dataset} whose records can be fetched individually by index.
 *
 * <p>Subclasses implement {@link #get(NDManager, long)} and {@link #availableSize()}. This class
 * supplies batched iteration through {@link DataIterable}, an optional size limit, and helpers that
 * split the dataset into sub-datasets. Each sub-dataset reuses this dataset's sampler, batchifiers,
 * pipelines, prefetch number and device.
 */
public abstract class RandomAccessDataset
implements Dataset {
    /** Sampler handed to {@link DataIterable} by the overloads of {@code getData} that take none. */
    protected Sampler sampler;
    /** Batchifier handed to {@link DataIterable} for the data part of each batch. */
    protected Batchifier dataBatchifier;
    /** Batchifier handed to {@link DataIterable} for the label part of each batch. */
    protected Batchifier labelBatchifier;
    /** Transform pipeline handed to {@link DataIterable}; may be {@code null} (builder default). */
    protected Pipeline pipeline;
    /** Target transform pipeline handed to {@link DataIterable}; may be {@code null}. */
    protected Pipeline targetPipeline;
    /** Prefetch setting handed to {@link DataIterable} (builder default is 2). */
    protected int prefetchNumber;
    /** Upper bound applied by {@link #size()}; {@code Long.MAX_VALUE} means no limit. */
    protected long limit;
    /** Device handed to {@link DataIterable}; may be {@code null} (builder default). */
    protected Device device;

    /**
     * Creates a dataset whose fields are left for the subclass to initialize. The internal
     * sub-dataset views use this constructor and copy the fields from their parent.
     */
    RandomAccessDataset() {
    }

    /**
     * Creates a dataset configured from a builder.
     *
     * @param builder the builder that holds the sampler, batchifiers, pipelines, prefetch number,
     *     limit and device
     * @throws NullPointerException if the builder has no sampler set
     */
    public RandomAccessDataset(BaseBuilder<?> builder) {
        this.sampler = builder.getSampler();
        this.dataBatchifier = builder.dataBatchifier;
        this.labelBatchifier = builder.labelBatchifier;
        this.pipeline = builder.pipeline;
        this.targetPipeline = builder.targetPipeline;
        this.prefetchNumber = builder.prefetchNumber;
        this.limit = builder.limit;
        this.device = builder.device;
    }

    /**
     * Fetches a single record.
     *
     * @param manager the manager used to create the record's arrays
     * @param index the index of the record
     * @return the record at {@code index}
     * @throws IOException if the record cannot be read
     */
    public abstract Record get(NDManager manager, long index) throws IOException;

    /**
     * Returns an iterable over batches, using this dataset's own sampler and no executor.
     *
     * @param manager the manager used to create batch arrays
     * @return the batch iterable
     * @throws IOException if preparing the dataset fails
     * @throws TranslateException if preparing the dataset fails
     */
    @Override
    public Iterable<Batch> getData(NDManager manager) throws IOException, TranslateException {
        return this.getData(manager, this.sampler, null);
    }

    /**
     * Returns an iterable over batches, using this dataset's own sampler.
     *
     * @param manager the manager used to create batch arrays
     * @param executorService the executor handed to {@link DataIterable}
     * @return the batch iterable
     * @throws IOException if preparing the dataset fails
     * @throws TranslateException if preparing the dataset fails
     */
    @Override
    public Iterable<Batch> getData(NDManager manager, ExecutorService executorService) throws IOException, TranslateException {
        return this.getData(manager, this.sampler, executorService);
    }

    /**
     * Returns an iterable over batches drawn by the given sampler, with no executor.
     *
     * @param manager the manager used to create batch arrays
     * @param sampler the sampler to use instead of this dataset's own
     * @return the batch iterable
     * @throws IOException if preparing the dataset fails
     * @throws TranslateException if preparing the dataset fails
     */
    public Iterable<Batch> getData(NDManager manager, Sampler sampler) throws IOException, TranslateException {
        return this.getData(manager, sampler, null);
    }

    /**
     * Prepares the dataset, then returns a {@link DataIterable} configured with this dataset's
     * batchifiers, pipelines, prefetch number and device.
     *
     * @param manager the manager used to create batch arrays
     * @param sampler the sampler to use
     * @param executorService the executor handed to {@link DataIterable}, or {@code null}
     * @return the batch iterable
     * @throws IOException if preparing the dataset fails
     * @throws TranslateException if preparing the dataset fails
     */
    public Iterable<Batch> getData(NDManager manager, Sampler sampler, ExecutorService executorService) throws IOException, TranslateException {
        this.prepare();
        return new DataIterable(this, manager, sampler, this.dataBatchifier, this.labelBatchifier, this.pipeline, this.targetPipeline, executorService, this.prefetchNumber, this.device);
    }

    /**
     * Returns the number of records this dataset exposes.
     *
     * @return the smaller of {@link #availableSize()} and the configured limit
     */
    public long size() {
        return Math.min(this.limit, this.availableSize());
    }

    /**
     * Returns the number of records available before the limit is applied.
     *
     * @return the unlimited record count
     */
    protected abstract long availableSize();

    /**
     * Shuffles the record indices and splits them into consecutive sub-datasets whose sizes are
     * proportional to {@code ratio}.
     *
     * <p>Every portion except the last receives {@code floor(ratio[i] / sum * size)} records. The
     * last portion receives the remainder.
     *
     * @param ratio the relative size of each portion; at least two non-negative values, not all
     *     zero
     * @return one sub-dataset per ratio entry
     * @throws IllegalArgumentException if fewer than two ratios are given, any is negative, or all
     *     are zero
     * @throws IOException if preparing the dataset fails
     * @throws TranslateException if preparing the dataset fails
     */
    public RandomAccessDataset[] randomSplit(int ... ratio) throws IOException, TranslateException {
        if (ratio.length < 2) {
            throw new IllegalArgumentException("Requires at least two split portion.");
        }
        // Sum as long so that large ratios cannot overflow.
        long total = 0L;
        for (int r : ratio) {
            if (r < 0) {
                throw new IllegalArgumentException("Split ratios must be non-negative, got " + Arrays.toString(ratio) + ".");
            }
            total += r;
        }
        if (total == 0L) {
            throw new IllegalArgumentException("At least one split ratio must be positive.");
        }
        this.prepare();
        int size = Math.toIntExact(this.size());
        int[] indices = IntStream.range(0, size).toArray();
        // Fisher-Yates: swapping position i only with a position in [0, i] makes every permutation
        // equally likely. Swapping with any position in [0, size) would bias the shuffle.
        for (int i = size - 1; i > 0; --i) {
            RandomAccessDataset.swap(indices, i, RandomUtils.nextInt(i + 1));
        }
        RandomAccessDataset[] ret = new RandomAccessDataset[ratio.length];
        int from = 0;
        for (int i = 0; i < ratio.length - 1; ++i) {
            int to = from + (int)((double)ratio[i] / (double)total * (double)size);
            ret[i] = this.newSubDataset(indices, from, to);
            from = to;
        }
        ret[ratio.length - 1] = this.newSubDataset(indices, from, size);
        return ret;
    }

    /**
     * Returns a view over the records in {@code [fromIndex, toIndex)}.
     *
     * @param fromIndex the first index, inclusive
     * @param toIndex the last index, exclusive
     * @return the sub-dataset
     * @throws IndexOutOfBoundsException if the range is not within {@code [0, size())} or
     *     {@code fromIndex > toIndex}
     */
    public RandomAccessDataset subDataset(int fromIndex, int toIndex) {
        return this.rangeSubDataset(fromIndex, toIndex);
    }

    /**
     * Returns a view over the records at the given indices. If the indices form a contiguous range
     * (as reported by {@link BulkDataIterable#isRange}), a range-based view is created instead.
     *
     * @param subIndices the indices of the records to include
     * @return the sub-dataset
     * @throws IndexOutOfBoundsException if a contiguous range extends past {@link #size()}
     */
    public RandomAccessDataset subDataset(List<Long> subIndices) {
        if (!subIndices.isEmpty() && BulkDataIterable.isRange(subIndices)) {
            long fromIndex = subIndices.get(0);
            long toIndex = fromIndex + (long)subIndices.size();
            return this.rangeSubDataset(Math.toIntExact(fromIndex), Math.toIntExact(toIndex));
        }
        return this.newSubDataset(subIndices);
    }

    /**
     * Returns a view over the records identified by {@code subRecordKeys}. Here
     * {@code recordKeys.get(i)} is the key of record {@code i}.
     *
     * @param recordKeys one distinct key per record, in record order
     * @param subRecordKeys the keys of the records to include
     * @param <K> the key type
     * @return the sub-dataset
     * @throws IllegalArgumentException if {@code recordKeys} does not match {@link #size()},
     *     contains duplicates, or lacks one of {@code subRecordKeys}
     */
    public <K> RandomAccessDataset subDataset(List<K> recordKeys, List<K> subRecordKeys) {
        if (this.size() != (long)recordKeys.size()) {
            throw new IllegalArgumentException("Requires as many record keys as there are records in the dataset.");
        }
        // Local, unshared map: a plain HashMap suffices and, unlike ConcurrentHashMap, accepts null
        // keys. Pre-sized to avoid rehashing.
        Map<K, Long> indicesOfRecordKeys = new HashMap<K, Long>(Math.max(16, (int)(recordKeys.size() / 0.75f) + 1));
        // Iterate instead of calling get(index) so that linked lists are not traversed quadratically.
        long index = 0L;
        for (K key : recordKeys) {
            Long prevIndex = indicesOfRecordKeys.put(key, index);
            if (prevIndex != null) {
                throw new IllegalArgumentException("At least two keys at position " + prevIndex + " and " + index + " are equal!");
            }
            ++index;
        }
        return this.subDataset(indicesOfRecordKeys, subRecordKeys);
    }

    /**
     * Returns a view over the records identified by {@code subRecordKeys}, using a precomputed
     * key-to-index map.
     *
     * @param indicesOfRecordKeys maps each record key to its record index
     * @param subRecordKeys the keys of the records to include
     * @param <K> the key type
     * @return the sub-dataset
     * @throws IllegalArgumentException if a key in {@code subRecordKeys} is not in the map
     */
    public <K> RandomAccessDataset subDataset(Map<K, Long> indicesOfRecordKeys, List<K> subRecordKeys) {
        List<Long> subIndices = new ArrayList<Long>(subRecordKeys.size());
        int position = 0;
        for (K recordKey : subRecordKeys) {
            Long index = indicesOfRecordKeys.get(recordKey);
            if (index == null) {
                throw new IllegalArgumentException("The key of subRecordKeys at position " + position + " is not contained in recordKeys!");
            }
            subIndices.add(index);
            ++position;
        }
        return this.subDataset(subIndices);
    }

    /**
     * Creates a view over the records {@code indices[from]} to {@code indices[to - 1]}. Subclasses
     * may override this to return a specialized view.
     *
     * @param indices the record index permutation
     * @param from the first position in {@code indices}, inclusive
     * @param to the last position in {@code indices}, exclusive
     * @return the sub-dataset
     */
    protected RandomAccessDataset newSubDataset(int[] indices, int from, int to) {
        return new SubDataset(this, indices, from, to);
    }

    /**
     * Creates a view over the records at the given indices. Subclasses may override this to
     * return a specialized view.
     *
     * @param subIndices the indices of the records to include
     * @return the sub-dataset
     */
    protected RandomAccessDataset newSubDataset(List<Long> subIndices) {
        return new SubDatasetByIndices(this, subIndices);
    }

    /**
     * Reads every record, one batch of size 1 at a time with a {@link SequenceSampler}, and
     * flattens data and labels into {@code Number} arrays.
     *
     * @param manager the manager used to create batch arrays
     * @return a pair (data, labels). Row {@code i} of each holds the flattened and concatenated
     *     arrays of the i-th batch, or {@code null} when that part of the batch has no arrays.
     * @throws IOException if preparing or reading the dataset fails
     * @throws TranslateException if preparing or reading the dataset fails
     */
    public Pair<Number[][], Number[][]> toArray(NDManager manager) throws IOException, TranslateException {
        BatchSampler sequential = new BatchSampler(new SequenceSampler(), 1, false);
        // getData calls prepare(). Obtain it before reading size() so that datasets whose size is
        // only known after preparation get correctly sized arrays.
        Iterable<Batch> batches = this.getData(manager, sequential);
        int size = Math.toIntExact(this.size());
        Number[][] data = new Number[size][];
        Number[][] labels = new Number[size][];
        int index = 0;
        for (Batch batch : batches) {
            try {
                data[index] = this.flattenRecord(batch.getData());
                labels[index] = this.flattenRecord(batch.getLabels());
            }
            finally {
                // Close the batch even when flattening throws.
                batch.close();
            }
            ++index;
        }
        return new Pair<Number[][], Number[][]>(data, labels);
    }

    /**
     * Flattens each array in {@code data} and concatenates the results into a single array.
     *
     * @param data the arrays to flatten
     * @return the concatenated values, or {@code null} if {@code data} is empty
     */
    private Number[] flattenRecord(NDList data) {
        NDList flattened = new NDList(data.stream().map(NDArray::flatten).collect(Collectors.toList()));
        if (flattened.isEmpty()) {
            return null;
        }
        if (flattened.size() == 1) {
            return flattened.get(0).toArray();
        }
        return NDArrays.concat(flattened).toArray();
    }

    /**
     * Validates {@code [fromIndex, toIndex)} against {@link #size()} and creates the range view.
     *
     * @param fromIndex the first index, inclusive
     * @param toIndex the last index, exclusive
     * @return the sub-dataset
     * @throws IndexOutOfBoundsException if the range is invalid
     */
    private RandomAccessDataset rangeSubDataset(int fromIndex, int toIndex) {
        int size = Math.toIntExact(this.size());
        if (fromIndex < 0 || fromIndex > toIndex || toIndex > size) {
            throw new IndexOutOfBoundsException("Invalid sub-dataset range [" + fromIndex + ", " + toIndex + ") for dataset of size " + size + ".");
        }
        int[] indices = IntStream.range(0, size).toArray();
        return this.newSubDataset(indices, fromIndex, toIndex);
    }

    /** Swaps {@code arr[i]} and {@code arr[j]} in place. */
    private static void swap(int[] arr, int i, int j) {
        int tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }

    /** A view whose record {@code k} is record {@code subIndices.get(k)} of the parent dataset. */
    private static final class SubDatasetByIndices
    extends RandomAccessDataset {
        private final RandomAccessDataset dataset;
        private final List<Long> subIndices;

        /**
         * Creates the view and copies the parent's configuration. The limit is not copied; the
         * view is unlimited.
         *
         * @param dataset the parent dataset
         * @param subIndices the parent indices to expose, in order
         */
        public SubDatasetByIndices(RandomAccessDataset dataset, List<Long> subIndices) {
            this.dataset = dataset;
            // Copy so that later changes to the caller's list cannot alter this view.
            this.subIndices = new ArrayList<Long>(subIndices);
            this.sampler = dataset.sampler;
            this.dataBatchifier = dataset.dataBatchifier;
            this.labelBatchifier = dataset.labelBatchifier;
            this.pipeline = dataset.pipeline;
            this.targetPipeline = dataset.targetPipeline;
            this.prefetchNumber = dataset.prefetchNumber;
            this.device = dataset.device;
            this.limit = Long.MAX_VALUE;
        }

        @Override
        public Record get(NDManager manager, long index) throws IOException {
            return this.dataset.get(manager, this.subIndices.get(Math.toIntExact(index)));
        }

        @Override
        protected long availableSize() {
            return this.subIndices.size();
        }

        /** No-op: the view does not prepare its parent. */
        @Override
        public void prepare(Progress progress) {
        }
    }

    /** A view whose record {@code k} is record {@code indices[from + k]} of the parent dataset. */
    private static final class SubDataset
    extends RandomAccessDataset {
        private final RandomAccessDataset dataset;
        private final int[] indices;
        private final int from;
        private final int to;

        /**
         * Creates the view and copies the parent's configuration. The limit is not copied; the
         * view is unlimited.
         *
         * @param dataset the parent dataset
         * @param indices the parent index permutation
         * @param from the first position in {@code indices}, inclusive
         * @param to the last position in {@code indices}, exclusive
         */
        public SubDataset(RandomAccessDataset dataset, int[] indices, int from, int to) {
            this.dataset = dataset;
            this.indices = indices;
            this.from = from;
            this.to = to;
            this.sampler = dataset.sampler;
            this.dataBatchifier = dataset.dataBatchifier;
            this.labelBatchifier = dataset.labelBatchifier;
            this.pipeline = dataset.pipeline;
            this.targetPipeline = dataset.targetPipeline;
            this.prefetchNumber = dataset.prefetchNumber;
            this.device = dataset.device;
            this.limit = Long.MAX_VALUE;
        }

        /**
         * Fetches record {@code index} of this view.
         *
         * @throws IndexOutOfBoundsException if {@code index} is negative or not below
         *     {@link #size()}
         */
        @Override
        public Record get(NDManager manager, long index) throws IOException {
            long size = this.size();
            // Reject negative indices too: with from > 0, they would otherwise silently read
            // records outside this view.
            if (index < 0 || index >= size) {
                throw new IndexOutOfBoundsException("index(" + index + ") out of range [0, " + size + ").");
            }
            return this.dataset.get(manager, this.indices[Math.toIntExact(index) + this.from]);
        }

        @Override
        protected long availableSize() {
            return this.to - this.from;
        }

        /** No-op: the view does not prepare its parent. */
        @Override
        public void prepare(Progress progress) {
        }
    }

    /**
     * Base class for builders of {@link RandomAccessDataset} subclasses.
     *
     * @param <T> the concrete builder type returned by the fluent setters via {@link #self()}
     */
    public static abstract class BaseBuilder<T extends BaseBuilder<T>> {
        protected Sampler sampler;
        protected Batchifier dataBatchifier = Batchifier.STACK;
        protected Batchifier labelBatchifier = Batchifier.STACK;
        protected Pipeline pipeline;
        protected Pipeline targetPipeline;
        protected int prefetchNumber = 2;
        protected long limit = Long.MAX_VALUE;
        protected Device device;

        /**
         * Returns the configured sampler.
         *
         * @return the sampler
         * @throws NullPointerException if no sampler has been set
         */
        public Sampler getSampler() {
            Objects.requireNonNull(this.sampler, "The sampler must be set");
            return this.sampler;
        }

        /**
         * Sets a {@link BatchSampler} that keeps the last partial batch. This is the same as
         * {@code setSampling(batchSize, random, false)}.
         *
         * @param batchSize the batch size
         * @param random whether to wrap a {@link RandomSampler} instead of a {@link SequenceSampler}
         * @return this builder
         */
        public T setSampling(int batchSize, boolean random) {
            return this.setSampling(batchSize, random, false);
        }

        /**
         * Sets a {@link BatchSampler} that wraps a {@link RandomSampler} if {@code random} is
         * true, and a {@link SequenceSampler} otherwise.
         *
         * @param batchSize the batch size
         * @param random whether to wrap a {@link RandomSampler}
         * @param dropLast the {@code dropLast} flag passed to {@link BatchSampler}
         * @return this builder
         */
        public T setSampling(int batchSize, boolean random, boolean dropLast) {
            this.sampler = random ? new BatchSampler(new RandomSampler(), batchSize, dropLast) : new BatchSampler(new SequenceSampler(), batchSize, dropLast);
            return this.self();
        }

        /**
         * Sets the sampler directly.
         *
         * @param sampler the sampler
         * @return this builder
         */
        public T setSampling(Sampler sampler) {
            this.sampler = sampler;
            return this.self();
        }

        /**
         * Sets the batchifier for data. The default is {@code Batchifier.STACK}.
         *
         * @param dataBatchifier the batchifier
         * @return this builder
         */
        public T optDataBatchifier(Batchifier dataBatchifier) {
            this.dataBatchifier = dataBatchifier;
            return this.self();
        }

        /**
         * Sets the batchifier for labels. The default is {@code Batchifier.STACK}.
         *
         * @param labelBatchifier the batchifier
         * @return this builder
         */
        public T optLabelBatchifier(Batchifier labelBatchifier) {
            this.labelBatchifier = labelBatchifier;
            return this.self();
        }

        /**
         * Replaces the data pipeline.
         *
         * @param pipeline the pipeline
         * @return this builder
         */
        public T optPipeline(Pipeline pipeline) {
            this.pipeline = pipeline;
            return this.self();
        }

        /**
         * Appends a transform to the data pipeline, creating the pipeline if it is absent.
         *
         * @param transform the transform to append
         * @return this builder
         */
        public T addTransform(Transform transform) {
            if (this.pipeline == null) {
                this.pipeline = new Pipeline();
            }
            this.pipeline.add(transform);
            return this.self();
        }

        /**
         * Replaces the target pipeline.
         *
         * @param targetPipeline the pipeline
         * @return this builder
         */
        public T optTargetPipeline(Pipeline targetPipeline) {
            this.targetPipeline = targetPipeline;
            return this.self();
        }

        /**
         * Appends a transform to the target pipeline, creating the pipeline if it is absent.
         *
         * @param transform the transform to append
         * @return this builder
         */
        public T addTargetTransform(Transform transform) {
            if (this.targetPipeline == null) {
                this.targetPipeline = new Pipeline();
            }
            this.targetPipeline.add(transform);
            return this.self();
        }

        /**
         * Sets the prefetch number handed to {@link DataIterable}. The default is 2.
         *
         * @param prefetchNumber the prefetch number
         * @return this builder
         */
        public T optPrefetchNumber(int prefetchNumber) {
            this.prefetchNumber = prefetchNumber;
            return this.self();
        }

        /**
         * Sets the device handed to {@link DataIterable}.
         *
         * @param device the device
         * @return this builder
         */
        public T optDevice(Device device) {
            this.device = device;
            return this.self();
        }

        /**
         * Caps {@link RandomAccessDataset#size()} at {@code limit}.
         *
         * @param limit the maximum number of records to expose
         * @return this builder
         */
        public T optLimit(long limit) {
            this.limit = limit;
            return this.self();
        }

        /**
         * Returns this builder as its concrete type.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}

