/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.AbstractNormalizer;
import org.nd4j.linalg.dataset.api.preprocessor.MinMaxStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.StandardizeStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;

/**
 * A {@link MultiDataNormalization} that normalizes each feature (input) and label (output) array
 * of a {@link MultiDataSet} with its own {@link NormalizerStrategy}.
 *
 * <p>A global strategy, when set, applies to every array of the corresponding kind. A per-index
 * strategy overrides the global one for that index. Arrays that end up with no strategy are
 * skipped: no statistics are collected for them and they are neither transformed nor reverted.
 *
 * <p>The configuration methods return {@code this} so calls can be chained.
 */
public class MultiNormalizerHybrid extends AbstractNormalizer implements MultiDataNormalization, Serializable {
    /** Fitted statistics keyed by input index; {@code null} until {@code fit} or a setter assigns it. */
    private Map<Integer, NormalizerStats> inputStats;
    /** Fitted statistics keyed by output index; {@code null} until {@code fit} or a setter assigns it. */
    private Map<Integer, NormalizerStats> outputStats;
    /** Strategy used for every input without a per-input override; may be {@code null}. */
    private NormalizerStrategy globalInputStrategy;
    /** Strategy used for every output without a per-output override; may be {@code null}. */
    private NormalizerStrategy globalOutputStrategy;
    /** Per-input overrides of {@link #globalInputStrategy}, keyed by input index. */
    private Map<Integer, NormalizerStrategy> perInputStrategies = new HashMap<>();
    /** Per-output overrides of {@link #globalOutputStrategy}, keyed by output index. */
    private Map<Integer, NormalizerStrategy> perOutputStrategies = new HashMap<>();

    /** Uses a new {@link StandardizeStrategy} as the global input strategy. */
    public MultiNormalizerHybrid standardizeAllInputs() {
        this.globalInputStrategy = new StandardizeStrategy();
        return this;
    }

    /** Uses a new no-argument {@link MinMaxStrategy} as the global input strategy. */
    public MultiNormalizerHybrid minMaxScaleAllInputs() {
        this.globalInputStrategy = new MinMaxStrategy();
        return this;
    }

    /** Uses a {@link MinMaxStrategy} built from {@code rangeFrom}/{@code rangeTo} as the global input strategy. */
    public MultiNormalizerHybrid minMaxScaleAllInputs(double rangeFrom, double rangeTo) {
        this.globalInputStrategy = new MinMaxStrategy(rangeFrom, rangeTo);
        return this;
    }

    /** Overrides the strategy for input {@code input} with a new {@link StandardizeStrategy}. */
    public MultiNormalizerHybrid standardizeInput(int input) {
        this.perInputStrategies.put(input, new StandardizeStrategy());
        return this;
    }

    /** Overrides the strategy for input {@code input} with a new no-argument {@link MinMaxStrategy}. */
    public MultiNormalizerHybrid minMaxScaleInput(int input) {
        this.perInputStrategies.put(input, new MinMaxStrategy());
        return this;
    }

    /** Overrides the strategy for input {@code input} with a {@link MinMaxStrategy} built from the given range. */
    public MultiNormalizerHybrid minMaxScaleInput(int input, double rangeFrom, double rangeTo) {
        this.perInputStrategies.put(input, new MinMaxStrategy(rangeFrom, rangeTo));
        return this;
    }

    /** Uses a new {@link StandardizeStrategy} as the global output strategy. */
    public MultiNormalizerHybrid standardizeAllOutputs() {
        this.globalOutputStrategy = new StandardizeStrategy();
        return this;
    }

    /** Uses a new no-argument {@link MinMaxStrategy} as the global output strategy. */
    public MultiNormalizerHybrid minMaxScaleAllOutputs() {
        this.globalOutputStrategy = new MinMaxStrategy();
        return this;
    }

    /** Uses a {@link MinMaxStrategy} built from {@code rangeFrom}/{@code rangeTo} as the global output strategy. */
    public MultiNormalizerHybrid minMaxScaleAllOutputs(double rangeFrom, double rangeTo) {
        this.globalOutputStrategy = new MinMaxStrategy(rangeFrom, rangeTo);
        return this;
    }

    /** Overrides the strategy for output {@code output} with a new {@link StandardizeStrategy}. */
    public MultiNormalizerHybrid standardizeOutput(int output) {
        this.perOutputStrategies.put(output, new StandardizeStrategy());
        return this;
    }

    /** Overrides the strategy for output {@code output} with a new no-argument {@link MinMaxStrategy}. */
    public MultiNormalizerHybrid minMaxScaleOutput(int output) {
        this.perOutputStrategies.put(output, new MinMaxStrategy());
        return this;
    }

    /** Overrides the strategy for output {@code output} with a {@link MinMaxStrategy} built from the given range. */
    public MultiNormalizerHybrid minMaxScaleOutput(int output, double rangeFrom, double rangeTo) {
        this.perOutputStrategies.put(output, new MinMaxStrategy(rangeFrom, rangeTo));
        return this;
    }

    /**
     * @param input index of the input array
     * @return the statistics for that input, or {@code null} when the stats map has no entry for it
     *     (for example because no strategy applied to that input during fitting)
     */
    public NormalizerStats getInputStats(int input) {
        return getInputStats().get(input);
    }

    /**
     * @param output index of the output array
     * @return the statistics for that output, or {@code null} when the stats map has no entry for it
     *     (for example because no strategy applied to that output during fitting)
     */
    public NormalizerStats getOutputStats(int output) {
        return getOutputStats().get(output);
    }

    /**
     * Returns the input statistics map. Guarded by {@code assertIsFit()} from
     * {@link AbstractNormalizer}, so this is only meant to be called on a fitted normalizer.
     */
    public Map<Integer, NormalizerStats> getInputStats() {
        assertIsFit();
        return this.inputStats;
    }

    /**
     * Returns the output statistics map. Guarded by {@code assertIsFit()} from
     * {@link AbstractNormalizer}, so this is only meant to be called on a fitted normalizer.
     */
    public Map<Integer, NormalizerStats> getOutputStats() {
        assertIsFit();
        return this.outputStats;
    }

    /**
     * Computes statistics from a single data set, replacing any previously fitted statistics.
     *
     * @param dataSet data to fit on
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public void fit(@NonNull MultiDataSet dataSet) {
        Objects.requireNonNull(dataSet, "dataSet is marked @NonNull but is null");
        Map<Integer, NormalizerStats.Builder> inputStatsBuilders = new HashMap<>();
        Map<Integer, NormalizerStats.Builder> outputStatsBuilders = new HashMap<>();
        fitPartial(dataSet, inputStatsBuilders, outputStatsBuilders);
        this.inputStats = buildAllStats(inputStatsBuilders);
        this.outputStats = buildAllStats(outputStatsBuilders);
    }

    /**
     * Resets the iterator, then accumulates statistics over every data set it yields, replacing
     * any previously fitted statistics.
     *
     * @param iterator source of data sets to fit on
     * @throws NullPointerException if {@code iterator} is null
     */
    @Override
    public void fit(@NonNull MultiDataSetIterator iterator) {
        Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
        Map<Integer, NormalizerStats.Builder> inputStatsBuilders = new HashMap<>();
        Map<Integer, NormalizerStats.Builder> outputStatsBuilders = new HashMap<>();
        iterator.reset();
        while (iterator.hasNext()) {
            fitPartial((MultiDataSet) iterator.next(), inputStatsBuilders, outputStatsBuilders);
        }
        this.inputStats = buildAllStats(inputStatsBuilders);
        this.outputStats = buildAllStats(outputStatsBuilders);
    }

    /**
     * Adds one data set's features and labels (with their masks) to the running stats builders.
     * The first call creates one builder per array index that has a strategy.
     */
    private void fitPartial(MultiDataSet dataSet, Map<Integer, NormalizerStats.Builder> inputStatsBuilders,
                    Map<Integer, NormalizerStats.Builder> outputStatsBuilders) {
        ensureStatsBuilders(inputStatsBuilders, globalInputStrategy, perInputStrategies, dataSet.numFeatureArrays());
        ensureStatsBuilders(outputStatsBuilders, globalOutputStrategy, perOutputStrategies, dataSet.numLabelsArrays());
        for (Map.Entry<Integer, NormalizerStats.Builder> entry : inputStatsBuilders.entrySet()) {
            int index = entry.getKey();
            entry.getValue().add(dataSet.getFeatures(index), dataSet.getFeaturesMaskArray(index));
        }
        for (Map.Entry<Integer, NormalizerStats.Builder> entry : outputStatsBuilders.entrySet()) {
            int index = entry.getKey();
            entry.getValue().add(dataSet.getLabels(index), dataSet.getLabelsMaskArray(index));
        }
    }

    /**
     * Populates an empty {@code builders} map with one fresh stats builder per index in
     * {@code [0, numArrays)} that has a (global or per-index) strategy. Does nothing if the map is
     * already populated.
     */
    private void ensureStatsBuilders(Map<Integer, NormalizerStats.Builder> builders, NormalizerStrategy globalStrategy,
                    Map<Integer, NormalizerStrategy> perArrayStrategies, int numArrays) {
        if (!builders.isEmpty()) {
            return;
        }
        for (int i = 0; i < numArrays; ++i) {
            NormalizerStrategy strategy = getStrategy(globalStrategy, perArrayStrategies, i);
            if (strategy != null) {
                builders.put(i, strategy.newStatsBuilder());
            }
        }
    }

    /** Builds every accumulated builder into a new index-to-stats map. */
    private Map<Integer, NormalizerStats> buildAllStats(@NonNull Map<Integer, NormalizerStats.Builder> builders) {
        Objects.requireNonNull(builders, "builders is marked @NonNull but is null");
        Map<Integer, NormalizerStats> result = new HashMap<>(builders.size());
        for (Map.Entry<Integer, NormalizerStats.Builder> entry : builders.entrySet()) {
            result.put(entry.getKey(), (NormalizerStats) entry.getValue().build());
        }
        return result;
    }

    /** Same as {@link #preProcess(MultiDataSet)}. */
    @Override
    public void transform(@NonNull MultiDataSet data) {
        Objects.requireNonNull(data, "data is marked @NonNull but is null");
        preProcess(data);
    }

    /**
     * Applies the configured strategies to the features and labels of {@code data}, using the
     * fitted statistics. Reading the statistics goes through the fit guard.
     */
    @Override
    public void preProcess(@NonNull MultiDataSet data) {
        Objects.requireNonNull(data, "data is marked @NonNull but is null");
        preProcess(data.getFeatures(), data.getFeaturesMaskArrays(), globalInputStrategy, perInputStrategies,
                        getInputStats());
        preProcess(data.getLabels(), data.getLabelsMaskArrays(), globalOutputStrategy, perOutputStrategies,
                        getOutputStats());
    }

    /**
     * Runs the applicable strategy over each array. A {@code null} {@code arrays} is a no-op;
     * {@code masks} may be {@code null}, otherwise it is indexed in parallel with {@code arrays}.
     */
    private void preProcess(INDArray[] arrays, INDArray[] masks, NormalizerStrategy globalStrategy,
                    Map<Integer, NormalizerStrategy> perArrayStrategy, Map<Integer, NormalizerStats> stats) {
        if (arrays == null) {
            return;
        }
        for (int i = 0; i < arrays.length; ++i) {
            NormalizerStrategy strategy = getStrategy(globalStrategy, perArrayStrategy, i);
            if (strategy != null) {
                strategy.preProcess(arrays[i], masks == null ? null : masks[i], stats.get(i));
            }
        }
    }

    /** Reverts both features and labels of {@code data}, passing along their mask arrays. */
    @Override
    public void revert(@NonNull MultiDataSet data) {
        Objects.requireNonNull(data, "data is marked @NonNull but is null");
        revertFeatures(data.getFeatures(), data.getFeaturesMaskArrays());
        revertLabels(data.getLabels(), data.getLabelsMaskArrays());
    }

    @Override
    public NormalizerType getType() {
        return NormalizerType.MULTI_HYBRID;
    }

    /** Reverts every feature array without masks. */
    @Override
    public void revertFeatures(@NonNull INDArray[] features) {
        Objects.requireNonNull(features, "features is marked @NonNull but is null");
        revertFeatures(features, null);
    }

    /** Reverts every feature array; {@code maskArrays} may be {@code null}. */
    @Override
    public void revertFeatures(@NonNull INDArray[] features, INDArray[] maskArrays) {
        Objects.requireNonNull(features, "features is marked @NonNull but is null");
        for (int i = 0; i < features.length; ++i) {
            revertFeatures(features, maskArrays, i);
        }
    }

    /**
     * Reverts the feature array at index {@code input} with its strategy, if it has one.
     *
     * @param features   all feature arrays
     * @param maskArrays matching mask arrays, or {@code null}
     * @param input      index of the array to revert
     */
    public void revertFeatures(@NonNull INDArray[] features, INDArray[] maskArrays, int input) {
        Objects.requireNonNull(features, "features is marked @NonNull but is null");
        NormalizerStrategy strategy = getStrategy(globalInputStrategy, perInputStrategies, input);
        if (strategy != null) {
            INDArray mask = maskArrays == null ? null : maskArrays[input];
            strategy.revert(features[input], mask, getInputStats(input));
        }
    }

    /** Reverts every label array without masks. */
    @Override
    public void revertLabels(@NonNull INDArray[] labels) {
        Objects.requireNonNull(labels, "labels is marked @NonNull but is null");
        revertLabels(labels, null);
    }

    /** Reverts every label array; {@code maskArrays} may be {@code null}. */
    @Override
    public void revertLabels(@NonNull INDArray[] labels, INDArray[] maskArrays) {
        Objects.requireNonNull(labels, "labels is marked @NonNull but is null");
        for (int i = 0; i < labels.length; ++i) {
            revertLabels(labels, maskArrays, i);
        }
    }

    /**
     * Reverts the label array at index {@code output} with its strategy, if it has one.
     *
     * @param labels     all label arrays
     * @param maskArrays matching mask arrays, or {@code null}
     * @param output     index of the array to revert
     */
    public void revertLabels(@NonNull INDArray[] labels, INDArray[] maskArrays, int output) {
        Objects.requireNonNull(labels, "labels is marked @NonNull but is null");
        NormalizerStrategy strategy = getStrategy(globalOutputStrategy, perOutputStrategies, output);
        if (strategy != null) {
            INDArray mask = maskArrays == null ? null : maskArrays[output];
            strategy.revert(labels[output], mask, getOutputStats(output));
        }
    }

    /**
     * Resolves the strategy for {@code index}. A per-index entry wins over the global strategy,
     * even when that entry is {@code null}.
     */
    private NormalizerStrategy getStrategy(NormalizerStrategy globalStrategy,
                    Map<Integer, NormalizerStrategy> perArrayStrategy, int index) {
        return perArrayStrategy.getOrDefault(index, globalStrategy);
    }

    /** Considered fitted once input statistics are present. */
    @Override
    protected boolean isFit() {
        return this.inputStats != null;
    }

    /**
     * Compares statistics and strategies field by field. Fields are read directly instead of via
     * {@link #getInputStats()}/{@link #getOutputStats()}: those go through {@code assertIsFit()},
     * which would otherwise get in the way of comparing unfitted instances.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MultiNormalizerHybrid)) {
            return false;
        }
        MultiNormalizerHybrid other = (MultiNormalizerHybrid) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.inputStats, other.inputStats)
                        && Objects.equals(this.outputStats, other.outputStats)
                        && Objects.equals(this.globalInputStrategy, other.globalInputStrategy)
                        && Objects.equals(this.globalOutputStrategy, other.globalOutputStrategy)
                        && Objects.equals(this.perInputStrategies, other.perInputStrategies)
                        && Objects.equals(this.perOutputStrategies, other.perOutputStrategies);
    }

    /** Symmetry hook for {@link #equals(Object)} in subclasses. */
    protected boolean canEqual(Object other) {
        return other instanceof MultiNormalizerHybrid;
    }

    /** Consistent with {@link #equals(Object)}; reads fields directly so unfitted instances hash too. */
    @Override
    public int hashCode() {
        return Objects.hash(inputStats, outputStats, globalInputStrategy, globalOutputStrategy, perInputStrategies,
                        perOutputStrategies);
    }

    /** Replaces the input statistics map as-is (no validation or copying). */
    public void setInputStats(Map<Integer, NormalizerStats> inputStats) {
        this.inputStats = inputStats;
    }

    /** Replaces the output statistics map as-is (no validation or copying). */
    public void setOutputStats(Map<Integer, NormalizerStats> outputStats) {
        this.outputStats = outputStats;
    }

    /** Sets the global input strategy; {@code null} means inputs without an override are skipped. */
    public void setGlobalInputStrategy(NormalizerStrategy globalInputStrategy) {
        this.globalInputStrategy = globalInputStrategy;
    }

    /** Sets the global output strategy; {@code null} means outputs without an override are skipped. */
    public void setGlobalOutputStrategy(NormalizerStrategy globalOutputStrategy) {
        this.globalOutputStrategy = globalOutputStrategy;
    }

    /** Replaces the per-input strategy map as-is (no copying). */
    public void setPerInputStrategies(Map<Integer, NormalizerStrategy> perInputStrategies) {
        this.perInputStrategies = perInputStrategies;
    }

    /** Replaces the per-output strategy map as-is (no copying). */
    public void setPerOutputStrategies(Map<Integer, NormalizerStrategy> perOutputStrategies) {
        this.perOutputStrategies = perOutputStrategies;
    }

    /** @return the global input strategy, possibly {@code null} */
    public NormalizerStrategy getGlobalInputStrategy() {
        return this.globalInputStrategy;
    }

    /** @return the global output strategy, possibly {@code null} */
    public NormalizerStrategy getGlobalOutputStrategy() {
        return this.globalOutputStrategy;
    }

    /** @return the live per-input strategy map (not a copy) */
    public Map<Integer, NormalizerStrategy> getPerInputStrategies() {
        return this.perInputStrategies;
    }

    /** @return the live per-output strategy map (not a copy) */
    public Map<Integer, NormalizerStrategy> getPerOutputStrategies() {
        return this.perOutputStrategies;
    }
}

