/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.layers.recurrent;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayerBp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans;

/**
 * Wrapper for the native {@code lstmLayer} custom op, usable both inside a SameDiff graph and
 * directly on INDArrays.
 *
 * <p>The op can return up to three outputs, selected by {@link LSTMLayerConfig}. In output order
 * they are: the full output sequence, the last H state and the last C state. At least one of them
 * must be enabled.
 */
public class LSTMLayer
extends DynamicCustomOp {
    /** Op configuration; source of the integer, floating-point and boolean op arguments. */
    private LSTMLayerConfig configuration;
    /** Layer weights; also decides whether bias ({@code hasBias}) and peephole ({@code hasPH}) inputs exist. */
    private LSTMLayerWeights weights;
    /** Optional {@code cLast} input (SameDiff mode only; stays {@code null} in INDArray mode or when absent). */
    private SDVariable cLast;
    /** Optional {@code yLast} input (SameDiff mode only; stays {@code null} in INDArray mode or when absent). */
    private SDVariable yLast;
    /** Optional sequence-length input (SameDiff mode only; stays {@code null} in INDArray mode or when absent). */
    private SDVariable maxTSLength;

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff      graph the op belongs to; must not be {@code null}
     * @param x             the layer input
     * @param cLast         optional {@code cLast} input, may be {@code null}
     * @param yLast         optional {@code yLast} input, may be {@code null}
     * @param maxTSLength   optional sequence-length input, may be {@code null}
     * @param weights       layer weights
     * @param configuration op configuration; at least one output must be enabled, otherwise a
     *                      {@code Preconditions.checkState} failure is raised
     */
    public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) {
        // NOTE(review): cLast is passed before yLast here, whereas bArgs(...) orders the presence
        // flags as maxTSLength, yLast, cLast. Confirm against LSTMLayerWeights#argsWithInputs that
        // each state ends up in the op input slot its flag describes.
        super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        this.configuration = configuration;
        this.weights = weights;
        this.cLast = cLast;
        this.yLast = yLast;
        this.maxTSLength = maxTSLength;
        // Validate before populating the op arguments so an unusable configuration fails fast.
        this.validateOutputs();
        this.addIArgument(this.iArgs());
        this.addTArgument(this.tArgs());
        this.addBArgument(this.bArgs(weights, maxTSLength, yLast, cLast));
    }

    /**
     * Creates the op for direct execution on INDArrays (no SameDiff graph).
     *
     * <p>The optional inputs are not stored as fields in this mode, so {@link #doDiff(List)} is not
     * meaningful for instances built this way.
     *
     * @param x             the layer input
     * @param cLast         optional {@code cLast} input, may be {@code null}
     * @param yLast         optional {@code yLast} input, may be {@code null}
     * @param maxTSLength   optional sequence-length input, may be {@code null}
     * @param lstmWeights   layer weights
     * @param configuration op configuration; at least one output must be enabled, otherwise a
     *                      {@code Preconditions.checkState} failure is raised
     */
    public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights lstmWeights, LSTMLayerConfig configuration) {
        // Same argument order as the SameDiff constructor (x first), so both construction paths
        // hand argsWithInputs an identical layout.
        // NOTE(review): see the cLast/yLast ordering note on the SameDiff constructor.
        super(null, null, lstmWeights.argsWithInputs(x, maxTSLength, cLast, yLast));
        this.configuration = configuration;
        this.weights = lstmWeights;
        this.validateOutputs();
        this.addIArgument(this.iArgs());
        this.addTArgument(this.tArgs());
        this.addBArgument(this.bArgs(this.weights, maxTSLength, yLast, cLast));
    }

    /** Requires the configuration to enable at least one of the three possible outputs. */
    private void validateOutputs() {
        boolean anyOutput = this.configuration.isRetLastH()
                || this.configuration.isRetLastC()
                || this.configuration.isRetFullSequence();
        Preconditions.checkState(anyOutput, "You have to specify at least one output you want to return. Use isRetLastC, isRetLastH and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
    }

    /**
     * Computes output data types: one entry per enabled output, all equal to the type of input 1.
     *
     * @param inputDataTypes data types of the op inputs; between 3 and 8 entries expected
     * @return a mutable list with one data type per enabled output, in output order
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && 3 <= inputDataTypes.size() && inputDataTypes.size() <= 8, "Expected amount of inputs to LSTMLayer between 3 inputs minimum (input, Wx, Wr only) or 8 maximum, got %s", inputDataTypes);
        DataType dt = inputDataTypes.get(1);
        Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
        // Same order doDiff uses for the gradients: full sequence, last H, last C.
        List<DataType> list = new ArrayList<>(3);
        if (this.configuration.isRetFullSequence()) {
            list.add(dt);
        }
        if (this.configuration.isRetLastH()) {
            list.add(dt);
        }
        if (this.configuration.isRetLastC()) {
            list.add(dt);
        }
        return list;
    }

    /**
     * Builds the backward pass as an {@link LSTMLayerBp} op and returns its output variables.
     *
     * @param grads gradients of the enabled outputs, in output order (full sequence, last H, last C);
     *              a disabled output contributes no entry and is passed on as {@code null}
     * @return the output variables of the created {@code LSTMLayerBp} op
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> grads) {
        int next = 0;
        SDVariable gradFullSeq = this.configuration.isRetFullSequence() ? grads.get(next++) : null;
        SDVariable gradLastH = this.configuration.isRetLastH() ? grads.get(next++) : null;
        SDVariable gradLastC = this.configuration.isRetLastC() ? grads.get(next++) : null;
        return Arrays.asList(new LSTMLayerBp(this.sameDiff, this.arg(0), this.cLast, this.yLast, this.maxTSLength, this.weights, this.configuration, gradFullSeq, gradLastH, gradLastC).outputVariables());
    }

    @Override
    public String opName() {
        return "lstmLayer";
    }

    /** Exposes the configuration's properties, delegating to {@code LSTMLayerConfig#toProperties}. */
    @Override
    public Map<String, Object> propertiesForFunction() {
        return this.configuration.toProperties(true, true);
    }

    /**
     * Integer op arguments, as enum ordinals in this fixed order: data format, direction mode,
     * gate activation, output activation, cell activation.
     *
     * <p>NOTE(review): these are passed positionally to the native op, so reordering the enum
     * constants in the config types would silently change the values sent.
     */
    @Override
    public long[] iArgs() {
        return new long[]{
                this.configuration.getLstmdataformat().ordinal(),
                this.configuration.getDirectionMode().ordinal(),
                this.configuration.getGateAct().ordinal(),
                this.configuration.getOutAct().ordinal(),
                this.configuration.getCellAct().ordinal()};
    }

    /** Floating-point op arguments: only the cell clip value. */
    @Override
    public double[] tArgs() {
        return new double[]{this.configuration.getCellClip()};
    }

    /**
     * Boolean op arguments. Presence flags come first (bias, maxTSLength, yLast, cLast, peephole
     * weights), followed by the output-selection flags (full sequence, last H, last C).
     */
    protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
        return new boolean[]{
                weights.hasBias(),
                maxTSLength != null,
                yLast != null,
                cLast != null,
                weights.hasPH(),
                this.configuration.isRetFullSequence(),
                this.configuration.isRetLastH(),
                this.configuration.isRetLastC()};
    }

    @Override
    public boolean isConfigProperties() {
        return true;
    }

    /** Name of the field holding the configuration (must match the {@code configuration} field). */
    @Override
    public String configFieldName() {
        return "configuration";
    }

    /** Number of enabled outputs (full sequence, last H, last C). */
    @Override
    public int getNumOutputs() {
        return Booleans.countTrue(this.configuration.isRetFullSequence(), this.configuration.isRetLastH(), this.configuration.isRetLastC());
    }

    /** No-argument constructor; leaves configuration and weights unset (presumably for reflective construction — verify). */
    public LSTMLayer() {
    }

    public LSTMLayerConfig getConfiguration() {
        return this.configuration;
    }

    public LSTMLayerWeights getWeights() {
        return this.weights;
    }
}

