/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.transforms;

import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.Exp;
import org.nd4j.linalg.factory.Nd4j;

public class OldSoftMax
extends BaseTransformOp {
    public OldSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
        super(sameDiff, i_v1, i_v2);
    }

    public OldSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
        super(sameDiff, i_v1, i_v2, inPlace);
    }

    public OldSoftMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
        super(sameDiff, i_v, inPlace);
    }

    public OldSoftMax(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, Object[] extraArgs) {
        super(sameDiff, i_v, shape, inPlace, extraArgs);
    }

    public OldSoftMax(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs) {
        super(sameDiff, i_v, extraArgs);
    }

    public OldSoftMax() {
    }

    public OldSoftMax(INDArray x, INDArray z) {
        this(x, null, z);
    }

    public OldSoftMax(INDArray x, INDArray z, long n) {
        this(x, null, z, n);
    }

    public OldSoftMax(INDArray x, INDArray y, INDArray z, long n) {
        super(x, y, z, n);
        Preconditions.checkArgument((x != null && x.rank() == 2 ? 1 : 0) != 0, (String)"OldSoftMax op supports rank 2 (2d) arrays only. Got x (source) array with shape: %ndShape", (Object)x);
        Preconditions.checkArgument((z != null && z.rank() == 2 ? 1 : 0) != 0, (String)"OldSoftMax op supports rank 2 (2d) arrays only. Got z (result) array with shape: %ndShape", (Object)z);
    }

    public OldSoftMax(INDArray x, INDArray y, INDArray z) {
        this(x, y, z, x.lengthLong());
    }

    public OldSoftMax(INDArray x) {
        super(x);
        Preconditions.checkArgument((x != null && x.rank() == 2 ? 1 : 0) != 0, (String)"OldSoftMax op supports rank 2 (2d) arrays only");
    }

    @Override
    public int opNum() {
        return 38;
    }

    @Override
    public boolean isExecSpecial() {
        return true;
    }

    @Override
    public String opName() {
        return "old_softmax";
    }

    @Override
    public String onnxName() {
        return "Softmax";
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + this.opName());
    }

    @Override
    public void exec() {
        this.exec(1);
    }

    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        this.passThrough = true;
    }

    @Override
    public void exec(int ... dimensions) {
        if (dimensions[0] != 1) {
            throw new IllegalArgumentException("Only supports row wise calculations");
        }
        if (this.x.isMatrix()) {
            INDArray maxAlongDimension = this.x.max(dimensions);
            if (!maxAlongDimension.isVector() && !maxAlongDimension.isScalar()) {
                throw new IllegalStateException("Max along dimension for input must either be a row vector or scalar");
            }
            INDArray xMinusMax = this.x.subColumnVector(maxAlongDimension);
            INDArray exp = this.z != null ? Nd4j.getExecutioner().execAndReturn(new Exp(xMinusMax, this.z)) : Nd4j.getExecutioner().execAndReturn(new Exp(xMinusMax));
            INDArray sum = exp.sum(dimensions);
            exp.diviColumnVector(sum);
            if (this.z == null) {
                this.z = exp;
            }
        } else if (this.x.isVector()) {
            double max = this.x.maxNumber().doubleValue();
            INDArray exp = this.z != null ? Nd4j.getExecutioner().execAndReturn(new Exp(this.x.sub(max), this.z)) : Nd4j.getExecutioner().execAndReturn(new Exp(this.x.sub(max)));
            exp.divi(exp.sumNumber().doubleValue());
            this.z = exp;
        }
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        SDVariable ret = this.f().softmaxDerivative(this.arg(), i_v.get(0), 1);
        return Collections.singletonList(ret);
    }
}

