/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.AbstractNormalizer;
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;

/**
 * Base class for normalizers working on {@link MultiDataSet}s, which may carry several feature and
 * label arrays. One statistics object is collected per feature array (and, when label fitting is
 * enabled via {@link #fitLabel(boolean)}, per label array), using builders obtained from
 * {@link #newBuilder()}. Applying and reverting the normalization is delegated to {@link #strategy}.
 *
 * @param <S> type of the statistics collected for each array
 */
public abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats>
extends AbstractNormalizer
implements MultiDataNormalization {
    /** Strategy that applies/reverts normalization given the statistics of one array. */
    protected NormalizerStrategy<S> strategy;
    /** Statistics per feature array, indexed like the data set's feature arrays; {@code null} until fitted or set. */
    private List<S> featureStats;
    /** Statistics per label array, indexed like the data set's label arrays; {@code null} unless labels were fitted or set. */
    private List<S> labelStats;
    /** Whether label arrays are fitted, preprocessed and reverted in addition to feature arrays. */
    private boolean fitLabels = false;

    protected AbstractMultiDataSetNormalizer() {
    }

    /**
     * @param strategy strategy used to apply and revert the normalization
     */
    protected AbstractMultiDataSetNormalizer(NormalizerStrategy<S> strategy) {
        this.strategy = strategy;
    }

    /**
     * Enables or disables processing of label arrays in {@code fit}, {@code preProcess} and
     * {@code revertLabels}.
     *
     * @param fitLabels {@code true} to also normalize labels
     */
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    /**
     * @return whether label arrays are processed in addition to feature arrays
     */
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * @return {@code true} once feature statistics are present (fitted or set explicitly)
     */
    @Override
    protected boolean isFit() {
        return this.featureStats != null;
    }

    /**
     * @param input index of the feature array
     * @return statistics of the given feature array
     */
    protected S getFeatureStats(int input) {
        return this.getFeatureStats().get(input);
    }

    /**
     * Returns the feature statistics after checking {@code assertIsFit()} (which presumably throws
     * when the normalizer has not been fitted — defined in {@link AbstractNormalizer}).
     *
     * @return statistics per feature array
     */
    protected List<S> getFeatureStats() {
        this.assertIsFit();
        return this.featureStats;
    }

    /**
     * @param output index of the label array
     * @return statistics of the given label array
     */
    protected S getLabelStats(int output) {
        return this.getLabelStats().get(output);
    }

    /**
     * Returns the label statistics after checking {@code assertIsFit()}. The fit check only looks at
     * feature statistics, so this may return {@code null} when labels were never fitted or set.
     *
     * @return statistics per label array, possibly {@code null}
     */
    protected List<S> getLabelStats() {
        this.assertIsFit();
        return this.labelStats;
    }

    /**
     * Fits the statistics on a single data set, replacing the feature statistics (and the label
     * statistics, if label fitting is enabled).
     *
     * @param dataSet data to fit on
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    @Override
    public void fit(@NonNull MultiDataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked @NonNull but is null");
        }
        List<NormalizerStats.Builder> featureNormBuilders = new ArrayList<>();
        List<NormalizerStats.Builder> labelNormBuilders = new ArrayList<>();
        this.fitPartial(dataSet, featureNormBuilders, labelNormBuilders);
        this.storeFittedStats(featureNormBuilders, labelNormBuilders);
    }

    /**
     * Resets the iterator and fits the statistics on every data set it yields, replacing the feature
     * statistics (and the label statistics, if label fitting is enabled).
     *
     * @param iterator source of data sets; {@code reset()} is called before iterating
     * @throws NullPointerException if {@code iterator} is {@code null}
     * @throws IllegalStateException if the data sets do not all have the same number of arrays
     */
    @Override
    public void fit(@NonNull MultiDataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        List<NormalizerStats.Builder> featureNormBuilders = new ArrayList<>();
        List<NormalizerStats.Builder> labelNormBuilders = new ArrayList<>();
        iterator.reset();
        while (iterator.hasNext()) {
            MultiDataSet next = (MultiDataSet) iterator.next();
            this.fitPartial(next, featureNormBuilders, labelNormBuilders);
        }
        this.storeFittedStats(featureNormBuilders, labelNormBuilders);
    }

    /** Builds and stores the accumulated statistics; label stats only when label fitting is on. */
    private void storeFittedStats(List<NormalizerStats.Builder> featureNormBuilders,
                                  List<NormalizerStats.Builder> labelNormBuilders) {
        this.featureStats = this.buildList(featureNormBuilders);
        if (this.isFitLabel()) {
            this.labelStats = this.buildList(labelNormBuilders);
        }
    }

    /** Builds one statistics object per builder, preserving order. */
    @SuppressWarnings("unchecked")
    private List<S> buildList(@NonNull List<NormalizerStats.Builder> builders) {
        if (builders == null) {
            throw new NullPointerException("builders is marked @NonNull but is null");
        }
        List<S> result = new ArrayList<>(builders.size());
        for (NormalizerStats.Builder builder : builders) {
            // Builders come from newBuilder(); subclasses are expected to return builders producing S.
            result.add((S) builder.build());
        }
        return result;
    }

    /** Adds one data set's feature (and, if enabled, label) arrays to the running builders. */
    private void fitPartial(MultiDataSet dataSet, List<NormalizerStats.Builder> featureStatsBuilders, List<NormalizerStats.Builder> labelStatsBuilders) {
        int numInputs = dataSet.numFeatureArrays();
        this.ensureStatsBuilders(featureStatsBuilders, numInputs, "feature");
        for (int i = 0; i < numInputs; ++i) {
            featureStatsBuilders.get(i).add(dataSet.getFeatures(i), dataSet.getFeaturesMaskArray(i));
        }
        if (this.isFitLabel()) {
            // Label builders are only needed (and only consumed) when labels are being fitted.
            int numOutputs = dataSet.numLabelsArrays();
            this.ensureStatsBuilders(labelStatsBuilders, numOutputs, "label");
            for (int i = 0; i < numOutputs; ++i) {
                labelStatsBuilders.get(i).add(dataSet.getLabels(i), dataSet.getLabelsMaskArray(i));
            }
        }
    }

    /**
     * Lazily creates {@code amount} builders on first use. Afterwards it verifies that every batch
     * has the same number of arrays, because statistics are matched to arrays by index.
     */
    private void ensureStatsBuilders(List<NormalizerStats.Builder> builders, int amount, String kind) {
        if (builders.isEmpty()) {
            for (int i = 0; i < amount; ++i) {
                builders.add(this.newBuilder());
            }
        } else if (builders.size() != amount) {
            throw new IllegalStateException("Inconsistent number of " + kind + " arrays while fitting: expected "
                    + builders.size() + " but got " + amount);
        }
    }

    /**
     * @return a fresh, empty statistics builder for one array
     */
    protected abstract NormalizerStats.Builder newBuilder();

    /**
     * Equivalent to {@link #preProcess(MultiDataSet)}.
     *
     * @param toPreProcess data set to normalize
     */
    @Override
    public void transform(@NonNull MultiDataSet toPreProcess) {
        if (toPreProcess == null) {
            throw new NullPointerException("toPreProcess is marked @NonNull but is null");
        }
        this.preProcess(toPreProcess);
    }

    /**
     * Passes every feature array (and every label array, if label fitting is enabled) together with
     * its mask and statistics to {@link #strategy}. The strategy's result is not used, so it
     * presumably normalizes the arrays in place.
     *
     * @param toPreProcess data set to normalize
     */
    @Override
    public void preProcess(@NonNull MultiDataSet toPreProcess) {
        if (toPreProcess == null) {
            throw new NullPointerException("toPreProcess is marked @NonNull but is null");
        }
        int numFeatures = toPreProcess.numFeatureArrays();
        int numLabels = toPreProcess.numLabelsArrays();
        for (int i = 0; i < numFeatures; ++i) {
            this.strategy.preProcess(toPreProcess.getFeatures(i), toPreProcess.getFeaturesMaskArray(i), this.getFeatureStats(i));
        }
        if (this.isFitLabel()) {
            for (int i = 0; i < numLabels; ++i) {
                this.strategy.preProcess(toPreProcess.getLabels(i), toPreProcess.getLabelsMaskArray(i), this.getLabelStats(i));
            }
        }
    }

    /**
     * Reverts the normalization of all feature arrays and, if label fitting is enabled, all label arrays.
     *
     * @param data data set to revert
     */
    @Override
    public void revert(@NonNull MultiDataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        this.revertFeatures(data.getFeatures(), data.getFeaturesMaskArrays());
        this.revertLabels(data.getLabels(), data.getLabelsMaskArrays());
    }

    /**
     * Reverts feature arrays without masks.
     *
     * @param features feature arrays, indexed like the fitted inputs
     */
    @Override
    public void revertFeatures(@NonNull INDArray[] features) {
        if (features == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        this.revertFeatures(features, null);
    }

    /**
     * Reverts feature arrays.
     *
     * @param features   feature arrays, indexed like the fitted inputs
     * @param maskArrays masks matching {@code features} by index, or {@code null} for no masks
     */
    @Override
    public void revertFeatures(@NonNull INDArray[] features, INDArray[] maskArrays) {
        if (features == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        for (int i = 0; i < features.length; ++i) {
            INDArray mask = maskArrays == null ? null : maskArrays[i];
            this.revertFeatures(features[i], mask, i);
        }
    }

    /**
     * Reverts a single feature array using the statistics of input {@code input}.
     *
     * @param features feature array to revert
     * @param mask     optional mask, may be {@code null}
     * @param input    index of the feature array's statistics
     */
    public void revertFeatures(@NonNull INDArray features, INDArray mask, int input) {
        if (features == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        this.strategy.revert(features, mask, this.getFeatureStats(input));
    }

    /**
     * Reverts label arrays without masks; a no-op per array unless label fitting is enabled.
     *
     * @param labels label arrays, indexed like the fitted outputs
     */
    @Override
    public void revertLabels(INDArray[] labels) {
        this.revertLabels(labels, null);
    }

    /**
     * Reverts label arrays; a no-op per array unless label fitting is enabled.
     *
     * @param labels     label arrays, indexed like the fitted outputs
     * @param labelsMask masks matching {@code labels} by index, or {@code null} for no masks
     */
    @Override
    public void revertLabels(@NonNull INDArray[] labels, INDArray[] labelsMask) {
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        for (int i = 0; i < labels.length; ++i) {
            INDArray mask = labelsMask == null ? null : labelsMask[i];
            this.revertLabels(labels[i], mask, i);
        }
    }

    /**
     * Reverts a single label array using the statistics of output {@code output}, but only when
     * label fitting is enabled.
     *
     * @param labels label array to revert
     * @param mask   optional mask, may be {@code null}
     * @param output index of the label array's statistics
     */
    public void revertLabels(@NonNull INDArray labels, INDArray mask, int output) {
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        if (this.isFitLabel()) {
            this.strategy.revert(labels, mask, this.getLabelStats(output));
        }
    }

    /**
     * @return number of feature arrays the normalizer was fitted on
     */
    public int numInputs() {
        return this.getFeatureStats().size();
    }

    /**
     * Requires label statistics to be present (labels fitted or set via {@link #setLabelStats}).
     * Otherwise a {@link NullPointerException} results.
     *
     * @return number of label arrays the normalizer was fitted on
     */
    public int numOutputs() {
        return this.getLabelStats().size();
    }

    /**
     * Compares strategy, statistics and the label-fitting flag. The fields are read directly instead
     * of through the getters, because the getters assert that the normalizer is fitted; equality
     * must also work for unfitted instances.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AbstractMultiDataSetNormalizer)) {
            return false;
        }
        AbstractMultiDataSetNormalizer<?> other = (AbstractMultiDataSetNormalizer<?>) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.strategy, other.strategy)
                && Objects.equals(this.featureStats, other.featureStats)
                && Objects.equals(this.labelStats, other.labelStats)
                && this.fitLabels == other.fitLabels;
    }

    /**
     * Lombok-style symmetry hook for {@link #equals(Object)}.
     */
    protected boolean canEqual(Object other) {
        return other instanceof AbstractMultiDataSetNormalizer;
    }

    /**
     * Hash consistent with {@link #equals(Object)}. It reads fields directly so that unfitted
     * instances can be hashed, and uses the same formula as before, so fitted instances hash
     * identically.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (this.strategy == null ? 43 : this.strategy.hashCode());
        result = result * prime + (this.featureStats == null ? 43 : this.featureStats.hashCode());
        result = result * prime + (this.labelStats == null ? 43 : this.labelStats.hashCode());
        result = result * prime + (this.fitLabels ? 79 : 97);
        return result;
    }

    /**
     * Sets the feature statistics directly; a non-null value makes {@link #isFit()} return {@code true}.
     *
     * @param featureStats statistics per feature array
     */
    public void setFeatureStats(List<S> featureStats) {
        this.featureStats = featureStats;
    }

    /**
     * Sets the label statistics directly.
     *
     * @param labelStats statistics per label array
     */
    public void setLabelStats(List<S> labelStats) {
        this.labelStats = labelStats;
    }
}

