/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor.classimbalance;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * Base class for pre-processors that undersample the majority class of a binary time-series
 * classification problem by rewriting the label mask.
 *
 * <p>The time axis is processed in windows of {@link #tbpttWindowSize} steps, walking backwards
 * from the end of the series (so a shorter leftover window, if any, sits at the start). Within
 * each window and for each example, majority-class steps are kept with probability
 * {@code p = (minorityCount / majorityCount) * (1 - target) / target}, clamped to 1. When no
 * clamping occurs, the expected minority fraction of the kept steps equals the target. Minority
 * steps are always kept. The resulting per-step probabilities are then sampled with a Bernoulli
 * distribution to produce the new mask.
 */
public abstract class BaseUnderSamplingPreProcessor {
    /**
     * Number of time steps per window. It must be positive. This class never assigns it, so
     * presumably subclasses set it.
     */
    protected int tbpttWindowSize;
    /**
     * When true (default), a window without minority steps has all of its majority steps masked.
     * When false, each present step is kept with probability {@code 1 - target}.
     */
    private boolean maskAllMajorityWindows = true;
    /** When true, a window containing any minority step keeps its original mask unchanged. */
    private boolean donotMaskMinorityWindows = false;

    /**
     * Stops fully masking windows that contain only majority-class steps. Such windows instead
     * keep each present step with probability {@code 1 - targetDist}.
     */
    public void donotMaskAllMajorityWindows() {
        this.maskAllMajorityWindows = false;
    }

    /** Leaves the original mask untouched for any window that contains a minority-class step. */
    public void donotMaskMinorityWindows() {
        this.donotMaskMinorityWindows = true;
    }

    /**
     * Computes an undersampled label mask for a batch of binary time series.
     *
     * <p>Neither {@code label} nor {@code labelMask} is modified.
     *
     * @param label         labels of shape [minibatch, 1 or 2, timeSteps]; with 2 columns every
     *                      unmasked step must be one-hot
     * @param labelMask     existing label mask of shape [minibatch, timeSteps], or {@code null} to
     *                      treat every step as present
     * @param minorityLabel with 2-column labels, the column holding the minority class; with
     *                      1-column labels, {@code 0} means the minority class is encoded as 0 and
     *                      any other value means it is encoded as 1
     * @param targetDist    desired fraction of minority-class steps after undersampling
     * @return a new mask, shaped like {@code labelMask}, whose entries are drawn from a Bernoulli
     *         distribution with the computed per-step keep probabilities
     * @throws IllegalArgumentException if the labels are not rank 3, have more than 2 columns, or
     *                                  (with 2 columns) are not one-hot
     * @throws IllegalStateException    if {@link #tbpttWindowSize} is not positive
     */
    public INDArray adjustMasks(INDArray label, INDArray labelMask, int minorityLabel, double targetDist) {
        // Validate the label shape before label.size(2) is used to build a default mask.
        this.validateLabelShape(label);
        if (labelMask == null) {
            labelMask = Nd4j.ones(label.size(0), label.size(2));
        }
        this.validateOneHot(label, labelMask);
        if (this.tbpttWindowSize <= 0) {
            // With a non-positive window the slice start never moves and the loop below never ends.
            throw new IllegalStateException("tbpttWindowSize must be positive, but was " + this.tbpttWindowSize);
        }

        INDArray bernoullis = Nd4j.zeros(labelMask.shape());
        long currentTimeSliceEnd = label.size(2);
        while (currentTimeSliceEnd > 0L) {
            long currentTimeSliceStart = Math.max(currentTimeSliceEnd - this.tbpttWindowSize, 0L);
            INDArray currentWindowBernoulli = bernoullis.get(NDArrayIndex.all(),
                    NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
            INDArray currentMask = labelMask.get(NDArrayIndex.all(),
                    NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
            // currentLabel ends up as an indicator: 1 where the step is minority class, else 0.
            INDArray currentLabel;
            if (label.size(1) == 2L) {
                currentLabel = label.get(NDArrayIndex.all(), NDArrayIndex.point(minorityLabel),
                        NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
            } else {
                currentLabel = label.get(NDArrayIndex.all(), NDArrayIndex.point(0L),
                        NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
                if (minorityLabel == 0) {
                    currentLabel = currentLabel.rsub(1.0);
                }
            }
            currentWindowBernoulli.assign(this.calculateBernoulli(currentLabel, currentMask, targetDist));
            currentTimeSliceEnd = currentTimeSliceStart;
        }
        return Nd4j.getExecutioner().exec(
                new BernoulliDistribution(Nd4j.createUninitialized(bernoullis.shape()), bernoullis),
                Nd4j.getRandom());
    }

    /**
     * Computes per-step keep probabilities for one window.
     *
     * @param minorityLabels     minority-class indicator for the window, [minibatch, windowLength]
     * @param labelMask          the caller's mask for the window (a view; it is not modified)
     * @param targetMinorityDist desired fraction of minority-class steps
     * @return keep probabilities for every step of the window
     */
    private INDArray calculateBernoulli(INDArray minorityLabels, INDArray labelMask, double targetMinorityDist) {
        // Only allocating ops (mul/rsub) touch the inputs. Both inputs may be views of the caller's
        // arrays, and castTo may return the same array when no conversion is needed.
        INDArray mask = labelMask.castTo(Nd4j.defaultFloatingPointType());
        INDArray labels = minorityLabels.castTo(mask.dataType());
        INDArray minorityClass = labels.mul(mask);
        INDArray majorityClass = labels.rsub(1.0).muli(mask);
        int minorityCount = minorityClass.sumNumber().intValue();
        int majorityCount = majorityClass.sumNumber().intValue();

        // No majority steps to drop, or minority windows are explicitly protected: keep the mask.
        if (majorityCount == 0 || minorityCount > 0 && this.donotMaskMinorityWindows) {
            return labelMask;
        }
        // Majority-only window when not masking those fully: keep each present step with 1 - target.
        if (minorityCount == 0 && !this.maskAllMajorityWindows) {
            return labelMask.mul(1.0 - targetMinorityDist);
        }

        // Examples with no present majority step in this window (e.g. fully padded rows) would give
        // 0/0 = NaN or x/0 = Inf. They have no majority entries to scale, so any finite divisor works.
        INDArray majorityPerExample = majorityClass.sum(1);
        BooleanIndexing.replaceWhere(majorityPerExample, 1.0, Conditions.equals(0.0));
        INDArray majorityBernoulliP = minorityClass.sum(1)
                .divi(majorityPerExample)
                .muli(1.0 - targetMinorityDist)
                .divi(targetMinorityDist);
        BooleanIndexing.replaceWhere(majorityBernoulliP, 1.0, Conditions.greaterThan(1.0));
        // Majority steps get their example's probability; minority steps are always kept (1).
        return majorityClass.muliColumnVector(majorityBernoulliP).addi(minorityClass);
    }

    /**
     * Checks that the labels are a rank-3 time series with at most two classes.
     *
     * @throws IllegalArgumentException if the rank or class count is invalid
     */
    private void validateLabelShape(INDArray label) {
        if (label.rank() != 3) {
            throw new IllegalArgumentException("UnderSamplingByMaskingPreProcessor can only be applied to a time series dataset");
        }
        if (label.size(1) > 2L) {
            throw new IllegalArgumentException("UnderSamplingByMaskingPreProcessor can only be applied to labels that represent binary classes. Label size was found to be " + label.size(1) + ".Expecting size=1 or size=2.");
        }
    }

    /**
     * For 2-column labels, checks that the class columns sum to 1 at every unmasked step.
     *
     * @throws IllegalArgumentException if the labels are not one-hot
     */
    private void validateOneHot(INDArray label, INDArray labelMask) {
        if (label.size(1) != 2L) {
            return;
        }
        // Cast the mask to the label type once so both sides of the comparison share one dtype.
        INDArray mask = labelMask.castTo(label.dataType());
        INDArray maskedClassSums = label.sum(1).mul(mask);
        if (!maskedClassSums.equals(mask)) {
            throw new IllegalArgumentException("Labels of size minibatchx2xtimesteps are expected to be one hot." + label.toString() + "\n is not one-hot");
        }
    }
}

