/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.preprocessors;

import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonIgnoreProperties(value={"hasMiniBatchDimension", "miniBatchSize"})
public class ReshapePreprocessor
extends BaseInputPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(ReshapePreprocessor.class);
    private long[] inputShape;
    private long[] targetShape;
    private boolean hasMiniBatchDimension = false;
    private int miniBatchSize;

    public ReshapePreprocessor(@JsonProperty(value="inputShape") long[] inputShape, @JsonProperty(value="targetShape") long[] targetShape) {
        this.inputShape = inputShape;
        this.targetShape = targetShape;
    }

    private static int prod(int[] array) {
        int prod = 1;
        for (int i : array) {
            prod *= i;
        }
        return prod;
    }

    private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
        int shapeLength = shape.length;
        long[] miniBatchShape = new long[shapeLength + 1];
        for (int i = 0; i < miniBatchShape.length; ++i) {
            miniBatchShape[i] = i == 0 ? miniBatchSize : shape[i - 1];
        }
        return miniBatchShape;
    }

    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (!this.hasMiniBatchDimension) {
            this.targetShape = ReshapePreprocessor.prependMiniBatchSize(this.targetShape, miniBatchSize);
            this.inputShape = ReshapePreprocessor.prependMiniBatchSize(this.inputShape, miniBatchSize);
            this.hasMiniBatchDimension = true;
            this.miniBatchSize = miniBatchSize;
        }
        if (this.miniBatchSize != miniBatchSize) {
            this.targetShape = ReshapePreprocessor.prependMiniBatchSize(ArrayUtils.subarray((long[])this.targetShape, (int)1, (int)this.targetShape.length), miniBatchSize);
            this.inputShape = ReshapePreprocessor.prependMiniBatchSize(ArrayUtils.subarray((long[])this.inputShape, (int)1, (int)this.targetShape.length), miniBatchSize);
            this.miniBatchSize = miniBatchSize;
        }
        if (ArrayUtil.prodLong((long[])input.shape()) == ArrayUtil.prodLong((long[])this.targetShape)) {
            if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape((INDArray)input)) {
                input = workspaceMgr.dup((Enum)ArrayType.ACTIVATIONS, input, 'c');
            }
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape));
        }
        throw new IllegalStateException("Input shape " + Arrays.toString(input.shape()) + " and output shape" + Arrays.toString(this.inputShape) + " do not match");
    }

    public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (!Arrays.equals(this.targetShape, output.shape())) {
            throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape()) + " (expected to be " + Arrays.toString(this.targetShape) + ")");
        }
        if (ArrayUtil.prodLong((long[])output.shape()) == ArrayUtil.prodLong((long[])this.targetShape)) {
            if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape((INDArray)output)) {
                output = workspaceMgr.dup((Enum)ArrayType.ACTIVATIONS, output, 'c');
            }
            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(this.inputShape));
        }
        throw new IllegalStateException("Output shape" + Arrays.toString(output.shape()) + " and input shape" + Arrays.toString(this.targetShape) + " do not match");
    }

    public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
        long[] shape = this.hasMiniBatchDimension ? this.targetShape : ReshapePreprocessor.prependMiniBatchSize(this.targetShape, 0L);
        switch (shape.length) {
            case 2: {
                return InputType.feedForward((long)shape[1]);
            }
            case 3: {
                return InputType.recurrent((long)shape[2], (long)shape[1]);
            }
            case 4: {
                if (this.inputShape.length == 1) {
                    return InputType.convolutional((long)shape[1], (long)shape[2], (long)shape[3]);
                }
                return InputType.convolutional((long)shape[2], (long)shape[3], (long)shape[1]);
            }
        }
        throw new UnsupportedOperationException("Cannot infer input type for reshape array " + Arrays.toString(shape));
    }

    public long[] getInputShape() {
        return this.inputShape;
    }

    public long[] getTargetShape() {
        return this.targetShape;
    }

    public boolean isHasMiniBatchDimension() {
        return this.hasMiniBatchDimension;
    }

    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    public void setInputShape(long[] inputShape) {
        this.inputShape = inputShape;
    }

    public void setTargetShape(long[] targetShape) {
        this.targetShape = targetShape;
    }

    public void setHasMiniBatchDimension(boolean hasMiniBatchDimension) {
        this.hasMiniBatchDimension = hasMiniBatchDimension;
    }

    public void setMiniBatchSize(int miniBatchSize) {
        this.miniBatchSize = miniBatchSize;
    }

    public String toString() {
        return "ReshapePreprocessor(inputShape=" + Arrays.toString(this.getInputShape()) + ", targetShape=" + Arrays.toString(this.getTargetShape()) + ", hasMiniBatchDimension=" + this.isHasMiniBatchDimension() + ", miniBatchSize=" + this.getMiniBatchSize() + ")";
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ReshapePreprocessor)) {
            return false;
        }
        ReshapePreprocessor other = (ReshapePreprocessor)((Object)o);
        if (!other.canEqual((Object)this)) {
            return false;
        }
        if (!Arrays.equals(this.getInputShape(), other.getInputShape())) {
            return false;
        }
        if (!Arrays.equals(this.getTargetShape(), other.getTargetShape())) {
            return false;
        }
        if (this.isHasMiniBatchDimension() != other.isHasMiniBatchDimension()) {
            return false;
        }
        return this.getMiniBatchSize() == other.getMiniBatchSize();
    }

    protected boolean canEqual(Object other) {
        return other instanceof ReshapePreprocessor;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        result = result * 59 + Arrays.hashCode(this.getInputShape());
        result = result * 59 + Arrays.hashCode(this.getTargetShape());
        result = result * 59 + (this.isHasMiniBatchDimension() ? 79 : 97);
        result = result * 59 + this.getMiniBatchSize();
        return result;
    }
}

