/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.random.impl;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public class Range
extends DynamicCustomOp {
    private Double from;
    private Double to;
    private Double delta;

    public Range() {
    }

    public Range(SameDiff sd, double from, double to, double step) {
        super(null, sd, new SDVariable[0]);
        this.addTArgument(from, to, step);
        this.from = from;
        this.to = to;
        this.delta = step;
    }

    @Override
    public int opNum() {
        return 4;
    }

    @Override
    public String opName() {
        return "range";
    }

    @Override
    public String onnxName() {
        return "Range";
    }

    @Override
    public String tensorflowName() {
        return "Range";
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
        NodeDef startNode = null;
        NodeDef endNode = null;
        NodeDef deltaNode = null;
        for (NodeDef node : graph.getNodeList()) {
            if (node.getName().equals(nodeDef.getInput(0))) {
                startNode = node;
            }
            if (node.getName().equals(nodeDef.getInput(1))) {
                endNode = node;
            }
            if (node.getName().equals(nodeDef.getInput(2))) {
                deltaNode = node;
            }
            if (startNode == null || endNode == null || deltaNode == null) continue;
            break;
        }
        INDArray start = TFGraphMapper.getInstance().getNDArrayFromTensor("value", startNode, graph);
        INDArray end = TFGraphMapper.getInstance().getNDArrayFromTensor("value", endNode, graph);
        INDArray delta = TFGraphMapper.getInstance().getNDArrayFromTensor("value", deltaNode, graph);
        if (start != null && end != null && delta != null) {
            if (endNode.getName() != null && endNode.getName().equalsIgnoreCase("Rank")) {
                this.from = start.getDouble(0L);
                this.to = this.from + 1.0;
                this.delta = 1.0;
            } else {
                this.from = start.getDouble(0L);
                this.to = end.getDouble(0L);
                this.delta = delta.getDouble(0L);
            }
            SDVariable[] outputVars = this.outputVariables();
            this.addTArgument(this.from, this.to, this.delta);
            String outputVertexId = outputVars[0].getVarName();
            if (this.sameDiff.getArrForVarName(outputVertexId) == null) {
                if (outputVars[0].getShape() == null) {
                    List<long[]> calcShape = this.calculateOutputShape();
                    this.sameDiff.putShapeForVarName(outputVars[0].getVarName(), calcShape.get(0));
                }
                long[] shape = outputVars[0].getShape();
                INDArray arr = Nd4j.create(shape);
                initWith.putArrayForVarName(outputVertexId, arr);
                this.addOutputArgument(arr);
            }
        }
        SDVariable fromVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(startNode.getName()));
        SDVariable toVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(endNode.getName()));
        SDVariable deltaVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(deltaNode.getName()));
    }

    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
        super.initFromOnnx(node, initWith, attributesForNode, graph);
    }

    @Override
    public List<long[]> calculateOutputShape() {
        long[] iArgs = this.iArgs();
        double[] tArgs = this.tArgs();
        INDArray[] inputArgs = this.inputArguments();
        int cnt = 0;
        if (iArgs.length > 0) {
            double e;
            int start = (int)iArgs[0];
            int stop = (int)iArgs[1];
            int step = (int)iArgs[2];
            if (start > stop) {
                while (e > (double)stop) {
                    ++cnt;
                    e = (double)step > 0.0 ? e - (double)step : e + (double)step;
                }
            } else {
                for (e = (double)start; e < (double)stop; e += (double)step) {
                    ++cnt;
                }
            }
            return Arrays.asList(new long[][]{{cnt}});
        }
        if (tArgs.length > 0) {
            double e;
            double start = tArgs[0];
            double stop = tArgs[1];
            double step = tArgs[2];
            if (start > stop) {
                while (e > stop) {
                    ++cnt;
                    e = step > 0.0 ? e - step : e + step;
                }
            } else {
                for (e = start; e < stop; e += step) {
                    ++cnt;
                }
            }
            return Arrays.asList(new long[][]{{cnt}});
        }
        if (inputArgs.length > 0) {
            double e;
            double start = inputArgs[0].getDouble(0L);
            double stop = inputArgs[1].getDouble(0L);
            double step = inputArgs[2].getDouble(0L);
            if (start > stop) {
                while (e > stop) {
                    ++cnt;
                    e = step > 0.0 ? e - step : e + step;
                }
            } else {
                for (e = start; e < stop; e += step) {
                    ++cnt;
                }
            }
            return Arrays.asList(new long[][]{{cnt}});
        }
        return Collections.emptyList();
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        return Collections.emptyList();
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }
}

