/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.indexing;

import com.google.common.primitives.Longs;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.IntervalIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndexAll;
import org.nd4j.linalg.indexing.NewAxis;
import org.nd4j.linalg.indexing.PointIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Resolves an array of {@link INDArrayIndex} objects against a source {@link INDArray} into the
 * values that describe the indexed result: its shape ({@link #getShapes()}), its strides
 * ({@link #getStrides()}), per-dimension offsets ({@link #getOffsets()}) and a single linear
 * offset ({@link #getOffset()}).
 *
 * <p>Typical use: construct with the source array, call {@link #exec(INDArrayIndex...)}, then read
 * the results through the getters. Instances carry mutable state and perform no synchronization.
 */
public class ShapeOffsetResolution
implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(ShapeOffsetResolution.class);
    /** The array the indexes are resolved against. */
    private INDArray arr;
    /**
     * Sparse (COO) only: one flag per non-{@link NewAxis} index, written by
     * {@link #resolveFixedDimensionsCOO}. The flag is 1 for a point index or a single-value
     * specified index, and 0 otherwise.
     */
    private int[] fixed;
    /** Not assigned anywhere in this class except through {@link #setPrependAxis(int[])}. */
    private int[] prependAxis;
    /** Per-dimension offsets of the resolved result. */
    private long[] offsets;
    /** Shape of the resolved result. */
    private long[] shapes;
    /** Strides of the resolved result. */
    private long[] strides;
    /** Linear offset of the resolved result; -1 until a resolution assigns it. */
    private long offset = -1L;

    /**
     * Creates a resolver for the given array.
     *
     * @param arr the array whose indexes will be resolved
     */
    public ShapeOffsetResolution(INDArray arr) {
        this.arr = arr;
    }

    /**
     * Attempts to resolve {@code indexes} through one of several special-cased fast paths:
     * <ul>
     *   <li>vector-specific combinations of all/point/interval indexes;</li>
     *   <li>rank-2 arrays indexed only by specified and "all" indexes;</li>
     *   <li>point indexes mixed with "all" indexes;</li>
     *   <li>interval indexes mixed with "all" indexes;</li>
     *   <li>new-axis indexes mixed with "all" indexes.</li>
     * </ul>
     *
     * @param indexes the (already resolved) indexes; at least one is expected
     * @return {@code true} if a fast path applied and the shape, stride and offset fields were
     *         written; {@code false} if the caller must fall back to the general resolution. Some
     *         {@code false} returns happen after fields were partially written.
     * @throws IllegalArgumentException from the specified-index path when it meets an index that is
     *         neither a {@link SpecifiedIndex} nor an {@link NDArrayIndexAll}
     */
    public boolean tryShortCircuit(INDArrayIndex ... indexes) {
        int minDimensions;
        int i;
        int pointIndex = 0;
        int interval = 0;
        int newAxis = 0;
        int numAll = 0;
        int numSpecified = 0;
        // Tally each kind of index. NDArrayIndexAll is kept out of the interval count and is
        // counted separately.
        for (i = 0; i < indexes.length; ++i) {
            if (indexes[i] instanceof PointIndex) {
                ++pointIndex;
            }
            if (indexes[i] instanceof SpecifiedIndex) {
                ++numSpecified;
                continue;
            }
            if (indexes[i] instanceof IntervalIndex && !(indexes[i] instanceof NDArrayIndexAll)) {
                ++interval;
                continue;
            }
            if (indexes[i] instanceof NewAxis) {
                ++newAxis;
                continue;
            }
            if (indexes[i] instanceof NDArrayIndexAll) {
                ++numAll;
            }
        }
        if (this.arr.isVector()) {
            // A lone "all" index keeps the whole array.
            if (indexes[0] instanceof NDArrayIndexAll && indexes.length == 1) {
                this.offset = 0L;
                this.shapes = this.arr.shape();
                this.strides = this.arr.stride();
                this.offsets = new long[this.arr.rank()];
                return true;
            }
            // Point followed by "all". The length guard stops a lone point index from reading
            // indexes[1] (which would throw); that case is handled by the next branch.
            if (indexes.length > 1 && indexes[0] instanceof PointIndex && indexes[1] instanceof NDArrayIndexAll) {
                this.shapes = new long[2];
                this.strides = new long[2];
                for (i = 0; i < 2; ++i) {
                    this.shapes[i] = 1L;
                    this.strides[i] = this.arr.stride(i);
                }
                this.offsets = new long[this.arr.rank()];
                this.offset = this.arr.isRowVector() ? indexes[0].offset() * this.strides[1] : indexes[0].offset() * this.strides[0];
                return true;
            }
            // A lone point index yields a 1x1 result positioned at the point.
            if (indexes[0] instanceof PointIndex && indexes.length == 1) {
                this.shapes = new long[2];
                this.strides = new long[2];
                for (i = 0; i < 2; ++i) {
                    this.shapes[i] = 1L;
                    this.strides[i] = this.arr.stride(i);
                }
                // Set offsets like the sibling branch above instead of leaving them null/stale.
                this.offsets = new long[this.arr.rank()];
                this.offset = this.arr.isRowVector() ? indexes[0].offset() * this.strides[1] : indexes[0].offset() * this.strides[0];
                return true;
            }
            if (this.arr.isRowVector()) {
                // Rank-1 array with a single interval.
                if (this.arr.rank() == 1 && indexes.length == 1 && indexes[0] instanceof IntervalIndex) {
                    this.offset = indexes[0].offset();
                    this.shapes = new long[1];
                    this.shapes[0] = indexes[0].length();
                    this.strides = new long[]{this.arr.stride(0)};
                    this.offsets = new long[1];
                    return true;
                }
                if (indexes[0] instanceof PointIndex) {
                    // Point followed by an interval: a 1 x interval.length() result.
                    if (indexes.length > 1 && indexes[1] instanceof IntervalIndex) {
                        this.offset = indexes[1].offset();
                        this.shapes = new long[2];
                        this.shapes[0] = 1L;
                        this.shapes[1] = indexes[1].length();
                        this.strides = new long[2];
                        this.strides[0] = 0L;
                        this.strides[1] = indexes[1].stride();
                        this.offsets = new long[2];
                        return true;
                    }
                } else if (!(indexes[0] instanceof IntervalIndex)) {
                    return false;
                }
            } else if (indexes.length > 1 && indexes[1] instanceof PointIndex) {
                if (indexes[0] instanceof IntervalIndex) {
                    // NOTE(review): the offset comes from the interval (indexes[0]), but
                    // shapes[0]/strides[0] come from the point index (indexes[1]). The row-vector
                    // branch above takes them from its interval index. Confirm whether indexes[0]
                    // was intended here.
                    this.offset = indexes[0].offset();
                    this.shapes = new long[2];
                    this.shapes[1] = 1L;
                    this.shapes[0] = indexes[1].length();
                    this.strides = new long[2];
                    this.strides[1] = 0L;
                    this.strides[0] = indexes[1].stride();
                    this.offsets = new long[2];
                    return true;
                }
            } else if (!(indexes[0] instanceof IntervalIndex)) {
                return false;
            }
        }
        // Rank-2 array indexed only by specified and "all" indexes.
        if (numSpecified > 0 && interval < 1 && newAxis < 1 && numAll > 0 && pointIndex < 1 && this.arr.rank() == 2) {
            int i2;
            this.shapes = new long[this.arr.rank()];
            this.strides = new long[this.arr.rank()];
            this.offsets = new long[this.arr.rank()];
            this.offset = 0L;
            boolean allSpecified = true;
            for (i2 = 0; i2 < 2; ++i2) {
                allSpecified = allSpecified && indexes[i2] instanceof SpecifiedIndex;
            }
            for (i2 = 0; i2 < this.arr.rank(); ++i2) {
                if (indexes[i2] instanceof SpecifiedIndex) {
                    SpecifiedIndex specifiedIndex = (SpecifiedIndex)indexes[i2];
                    if (specifiedIndex.getIndexes().length >= this.arr.rank()) {
                        return false;
                    }
                    this.shapes[i2] = indexes[i2].length();
                    this.offsets[i2] = indexes[i2].offset();
                    if (!allSpecified || i2 == 0 && allSpecified) {
                        this.offset = this.offsets[i2] * (long)this.arr.stride(i2);
                    }
                    if (indexes[i2].length() != 1L) {
                        // NOTE(review): getIndexes() is indexed by the dimension number i2, not by
                        // a position within the specified values. Confirm this is intended.
                        this.strides[i2] = (long)this.arr.stride(i2) * specifiedIndex.getIndexes()[i2];
                        continue;
                    }
                    this.strides[i2] = 1L;
                    continue;
                }
                if (indexes[i2] instanceof NDArrayIndexAll) {
                    this.shapes[i2] = this.arr.size(i2);
                    this.strides[i2] = this.arr.tensorAlongDimension(0, i2).elementWiseStride();
                    continue;
                }
                throw new IllegalArgumentException("Illegal opType of index " + indexes[i2].getClass().getName());
            }
            return true;
        }
        // Point indexes mixed with "all": each point drops a dimension and contributes to the offset.
        if (numSpecified < 1 && interval < 1 && newAxis < 1 && pointIndex > 0 && numAll > 0) {
            minDimensions = Math.max(this.arr.rank() - pointIndex, 2);
            long[] shape = new long[minDimensions];
            Arrays.fill(shape, 1L);
            long[] stride = new long[minDimensions];
            Arrays.fill(stride, (long)this.arr.elementStride());
            long[] offsets = new long[minDimensions];
            long offset = 0L;
            int currIndex = 0;
            int arrIndex = 0;
            for (int i3 = 0; i3 < indexes.length; ++i3) {
                if (indexes[i3] instanceof NDArrayIndexAll) {
                    shape[currIndex] = this.arr.size(arrIndex);
                    stride[currIndex] = this.arr.stride(arrIndex);
                    ++currIndex;
                    ++arrIndex;
                    continue;
                }
                offset += indexes[i3].offset() * (long)this.arr.stride(i3);
                ++arrIndex;
            }
            // A leading point on a matrix: flip shape/stride. (The former "else if" that also
            // required indexes[1] to be an interval could never run: its condition implied this one.)
            if (this.arr.isMatrix() && indexes[0] instanceof PointIndex) {
                shape = ArrayUtil.reverseCopy((long[])shape);
                stride = ArrayUtil.reverseCopy((long[])stride);
            }
            this.strides = stride;
            this.shapes = shape;
            this.offsets = offsets;
            this.offset = offset;
            return true;
        }
        // Interval indexes mixed with "all".
        if (numSpecified < 1 && interval > 0 && newAxis < 1 && pointIndex < 1 && numAll > 0) {
            int i4;
            minDimensions = Math.max(this.arr.rank(), 2);
            long[] shape = new long[minDimensions];
            Arrays.fill(shape, 1L);
            long[] stride = new long[minDimensions];
            Arrays.fill(stride, (long)this.arr.elementStride());
            long[] offsets = new long[minDimensions];
            for (i4 = 0; i4 < shape.length; ++i4) {
                if (indexes[i4] instanceof NDArrayIndexAll) {
                    shape[i4] = this.arr.size(i4);
                    stride[i4] = this.arr.stride(i4);
                    offsets[i4] = indexes[i4].offset();
                    continue;
                }
                if (!(indexes[i4] instanceof IntervalIndex)) continue;
                shape[i4] = indexes[i4].length();
                stride[i4] = indexes[i4].stride() * (long)this.arr.stride(i4);
                offsets[i4] = indexes[i4].offset();
            }
            this.shapes = shape;
            this.strides = stride;
            this.offsets = offsets;
            this.offset = 0L;
            for (i4 = 0; i4 < indexes.length; ++i4) {
                this.offset += offsets[i4] * (stride[i4] / indexes[i4].stride());
            }
            return true;
        }
        // New-axis indexes mixed with "all".
        if (numSpecified < 1 && interval < 1 && newAxis > 0 && pointIndex < 1 && numAll > 0) {
            int i5;
            minDimensions = Math.max(this.arr.rank(), 2) + newAxis;
            long[] shape = new long[minDimensions];
            Arrays.fill(shape, 1L);
            long[] stride = new long[minDimensions];
            Arrays.fill(stride, (long)this.arr.elementStride());
            long[] offsets = new long[minDimensions];
            int allEncountered = 0;
            for (i5 = 0; i5 < minDimensions; ++i5) {
                if (i5 >= indexes.length) {
                    shape[i5] = this.arr.size(allEncountered);
                    stride[i5] = this.arr.stride(allEncountered);
                    ++allEncountered;
                    continue;
                }
                if (indexes[i5] instanceof NewAxis || !(indexes[i5] instanceof NDArrayIndexAll)) continue;
                shape[allEncountered] = this.arr.size(allEncountered);
                stride[allEncountered] = this.arr.stride(allEncountered);
                ++allEncountered;
            }
            this.shapes = shape;
            this.strides = stride;
            this.offsets = offsets;
            // Start from zero like the interval path; accumulating onto the prior value left the
            // field-initialiser -1 (or a stale offset) as the result.
            this.offset = 0L;
            for (i5 = 0; i5 < indexes.length; ++i5) {
                this.offset += offsets[i5] * (stride[i5] / indexes[i5].stride());
            }
            return true;
        }
        return false;
    }

    /**
     * Resolves {@code indexes} against the source array and stores the resulting shape, strides,
     * per-dimension offsets and linear offset in this instance.
     *
     * <p>Steps:
     * <ol>
     *   <li>For sparse arrays, compute {@link #getFixed()} first.</li>
     *   <li>Bounds-check point indexes.</li>
     *   <li>Normalise the indexes with {@link NDArrayIndex#resolve}.</li>
     *   <li>Try {@link #tryShortCircuit}; if no fast path applies, run the general accumulation.</li>
     * </ol>
     *
     * @param indexes the indexes to resolve
     * @throws IllegalArgumentException if a point index is out of bounds
     */
    public void exec(INDArrayIndex ... indexes) {
        long[] shape = this.arr.shape();
        if (this.arr.isSparse()) {
            this.resolveFixedDimensionsCOO(indexes);
        }
        // Reject point indexes at or beyond the size of their dimension. A single index on a
        // vector is checked against shape[1].
        // NOTE(review): shape[1] is used regardless of the vector's orientation, and it does not
        // exist when shape has length 1. Confirm the intended bound for those cases.
        for (int i = 0; i < indexes.length; ++i) {
            INDArrayIndex idx = indexes[i];
            if (!(idx instanceof PointIndex)) {
                continue;
            }
            long bound = this.arr.isVector() && indexes.length == 1 ? shape[i + 1] : shape[i];
            if (idx.current() >= bound) {
                throw new IllegalArgumentException("INDArrayIndex[" + i + "] is out of bounds (value: " + idx.current() + ")");
            }
        }
        indexes = NDArrayIndex.resolve(this.arr.shapeInfoDataBuffer(), indexes);
        if (this.tryShortCircuit(indexes)) {
            return;
        }
        // General path. Surviving dimensions go into the accum* lists. Point indexes are collected
        // separately and later folded into the linear offset.
        int numIntervals = 0;
        ArrayList<Long> accumShape = new ArrayList<Long>();
        ArrayList<Long> accumStrides = new ArrayList<Long>();
        ArrayList<Long> accumOffsets = new ArrayList<Long>();
        ArrayList<Long> intervalStrides = new ArrayList<Long>();
        ArrayList<Long> pointStrides = new ArrayList<Long>();
        ArrayList<Long> pointOffsets = new ArrayList<Long>();
        int numPointIndexes = 0;
        int shapeIndex = 0;
        int strideIndex = 0;
        for (int i = 0; i < indexes.length; ++i) {
            INDArrayIndex idx = indexes[i];
            if (idx instanceof PointIndex) {
                pointOffsets.add(idx.offset());
                pointStrides.add(Long.valueOf(this.arr.stride(strideIndex)));
                ++numPointIndexes;
                ++shapeIndex;
                ++strideIndex;
                continue;
            }
            if (idx instanceof NewAxis) {
                // A new axis adds a size-1 dimension with zero stride and consumes no source dimension.
                accumShape.add(1L);
                accumOffsets.add(0L);
                accumStrides.add(0L);
                continue;
            }
            if (idx instanceof IntervalIndex && !(idx instanceof NDArrayIndexAll) || idx instanceof SpecifiedIndex) {
                if (idx instanceof IntervalIndex) {
                    accumStrides.add((long)this.arr.stride(strideIndex) * idx.stride());
                    intervalStrides.add(idx.stride());
                    ++numIntervals;
                } else {
                    accumStrides.add(Long.valueOf(this.arr.stride(strideIndex)));
                }
                accumShape.add(idx.length());
                accumOffsets.add(idx.offset());
                ++shapeIndex;
                ++strideIndex;
                continue;
            }
            // Any other index (e.g. "all") keeps the source dimension as-is.
            accumShape.add(shape[shapeIndex++]);
            accumStrides.add(Long.valueOf(this.arr.stride(strideIndex++)));
            accumOffsets.add(idx.offset());
        }
        // Dimensions not covered by any index.
        while (shapeIndex < shape.length) {
            if (Shape.isVector(shape)) {
                accumShape.add(1L);
                ++shapeIndex;
                continue;
            }
            accumShape.add(shape[shapeIndex++]);
        }
        int delta = shape.length <= 2 ? shape.length : shape.length - numPointIndexes;
        boolean needsFilledIn = accumShape.size() != accumStrides.size() && accumOffsets.size() != accumShape.size();
        while (accumOffsets.size() < delta && needsFilledIn) {
            accumOffsets.add(0L);
        }
        // Ensure the result has at least two dimensions.
        while (accumShape.size() < 2) {
            if (Shape.isRowVectorShape(this.arr.shape())) {
                accumShape.add(0, 1L);
                continue;
            }
            accumShape.add(1L);
        }
        while (strideIndex < accumShape.size()) {
            accumStrides.add(Long.valueOf(this.arr.stride(strideIndex++)));
        }
        // NOTE(review): this loop inspects the element at trailingZeroRemove but always removes the
        // last element, and can walk below index 0 if no zero is found. Confirm intended.
        int trailingZeroRemove = accumOffsets.size() - 1;
        while (accumOffsets.size() > accumShape.size()) {
            if (accumOffsets.get(trailingZeroRemove) == 0L) {
                accumOffsets.remove(accumOffsets.size() - 1);
            }
            --trailingZeroRemove;
        }
        if (accumStrides.size() < accumOffsets.size()) {
            accumStrides.addAll(pointStrides);
        }
        while (accumOffsets.size() < accumShape.size()) {
            if (Shape.isRowVectorShape(this.arr.shape())) {
                accumOffsets.add(0, 0L);
                continue;
            }
            accumOffsets.add(0L);
        }
        if (Shape.isMatrix(shape) && indexes[0] instanceof PointIndex && indexes[1] instanceof NDArrayIndexAll) {
            Collections.reverse(accumShape);
        }
        if (this.arr.isMatrix() && indexes[0] instanceof PointIndex && indexes[1] instanceof IntervalIndex) {
            this.shapes = new long[2];
            this.shapes[0] = 1L;
            IntervalIndex idx = (IntervalIndex)indexes[1];
            this.shapes[1] = idx.length();
        } else {
            this.shapes = Longs.toArray(accumShape);
        }
        boolean isColumnVector = Shape.isColumnVectorShape(this.shapes);
        while (accumStrides.size() < accumOffsets.size()) {
            if (!isColumnVector) {
                accumStrides.add(0, Long.valueOf(this.arr.elementStride()));
                continue;
            }
            accumStrides.add(Long.valueOf(this.arr.elementStride()));
        }
        this.strides = Longs.toArray(accumStrides);
        this.offsets = Longs.toArray(accumOffsets);
        // Point indexes contribute offset * stride terms to the linear offset.
        // (pointOffsets and pointStrides are always appended together, so their sizes match.)
        if (numPointIndexes > 0 && !pointStrides.isEmpty()) {
            this.offset = this.arr.isRowVector() && !intervalStrides.isEmpty() && pointOffsets.get(0) == 0L && !(indexes[1] instanceof IntervalIndex)
                    ? indexes[1].offset()
                    : ArrayUtil.dotProductLong2(pointOffsets, pointStrides);
        } else {
            this.offset = 0L;
        }
        // Add the contribution of the surviving dimensions. (The former nested ternary had two
        // identical arms for the rank > 2 case; they are merged here.)
        if (numIntervals > 0 && this.arr.rank() > 2) {
            this.offset += ArrayUtil.dotProductLong2(accumOffsets, accumStrides);
        } else if (numIntervals > 0 && this.anyHaveStrideOne(indexes)) {
            this.offset += ArrayUtil.calcOffsetLong2(accumShape, accumOffsets, accumStrides);
        } else {
            this.offset += ArrayUtil.calcOffsetLong2(accumShape, accumOffsets, accumStrides) / (long)Math.max(1, numIntervals);
        }
    }

    /**
     * Fills {@link #getFixed()} for sparse (COO) arrays. Indexes are walked in order and one entry is
     * written per non-{@link NewAxis} index: 1 for a {@link PointIndex} or a {@link SpecifiedIndex}
     * with exactly one value, 0 for an interval or "all" index.
     *
     * @param indexes the indexes to inspect; more non-new-axis indexes than the array's rank will
     *                overflow the {@code fixed} array
     */
    public void resolveFixedDimensionsCOO(INDArrayIndex ... indexes) {
        this.fixed = new int[this.arr.rank()];
        int j = 0;
        for (int i = 0; i < indexes.length; ++i) {
            if (indexes[i] instanceof PointIndex) {
                this.fixed[j] = 1;
                ++j;
            }
            if (indexes[i] instanceof IntervalIndex || indexes[i] instanceof NDArrayIndexAll) {
                this.fixed[j] = 0;
                ++j;
            }
            if (indexes[i] instanceof SpecifiedIndex) {
                SpecifiedIndex idx = (SpecifiedIndex)indexes[i];
                this.fixed[j] = idx.getIndexes().length == 1 ? 1 : 0;
                ++j;
            }
            // NewAxis indexes write no entry.
        }
    }

    /** Placeholder for sparse (COO) offset resolution; currently does nothing. */
    public void resolveSparseOffsetsCOO() {
    }

    /**
     * Returns whether any of the given indexes has a stride of exactly 1.
     *
     * @param indexes the indexes to inspect
     * @return {@code true} if at least one index reports {@code stride() == 1}
     */
    private boolean anyHaveStrideOne(INDArrayIndex ... indexes) {
        for (INDArrayIndex index : indexes) {
            if (index.stride() == 1L) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether every given index has a non-zero offset. Despite the name, this tests
     * {@code != 0}, not {@code > 0}. Not called anywhere in this class.
     *
     * @param indexes the indexes to inspect
     * @return {@code false} as soon as an index with offset 0 is found, otherwise {@code true}
     */
    private boolean allIndexGreatherThanZero(INDArrayIndex ... indexes) {
        for (INDArrayIndex index : indexes) {
            if (index.offset() == 0L) {
                return false;
            }
        }
        return true;
    }

    /** @return the array the indexes are resolved against */
    public INDArray getArr() {
        return this.arr;
    }

    /** @return the sparse (COO) fixed-dimension flags, or null if never computed */
    public int[] getFixed() {
        return this.fixed;
    }

    /** @return the prepend-axis array as last set via {@link #setPrependAxis(int[])} */
    public int[] getPrependAxis() {
        return this.prependAxis;
    }

    /** @return the per-dimension offsets of the resolved result */
    public long[] getOffsets() {
        return this.offsets;
    }

    /** @return the shape of the resolved result */
    public long[] getShapes() {
        return this.shapes;
    }

    /** @return the strides of the resolved result */
    public long[] getStrides() {
        return this.strides;
    }

    /** @return the linear offset of the resolved result; -1 if no resolution has assigned it */
    public long getOffset() {
        return this.offset;
    }

    public void setArr(INDArray arr) {
        this.arr = arr;
    }

    public void setFixed(int[] fixed) {
        this.fixed = fixed;
    }

    public void setPrependAxis(int[] prependAxis) {
        this.prependAxis = prependAxis;
    }

    public void setOffsets(long[] offsets) {
        this.offsets = offsets;
    }

    public void setShapes(long[] shapes) {
        this.shapes = shapes;
    }

    public void setStrides(long[] strides) {
        this.strides = strides;
    }

    public void setOffset(long offset) {
        this.offset = offset;
    }

    /**
     * Two resolutions are equal when they share an equal source array and identical fixed,
     * prepend-axis, offsets, shapes and strides arrays (compared element-wise), plus the same
     * linear offset.
     */
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ShapeOffsetResolution)) {
            return false;
        }
        ShapeOffsetResolution other = (ShapeOffsetResolution)o;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray this$arr = this.getArr();
        INDArray other$arr = other.getArr();
        if (this$arr == null ? other$arr != null : !this$arr.equals(other$arr)) {
            return false;
        }
        if (!Arrays.equals(this.getFixed(), other.getFixed())) {
            return false;
        }
        if (!Arrays.equals(this.getPrependAxis(), other.getPrependAxis())) {
            return false;
        }
        if (!Arrays.equals(this.getOffsets(), other.getOffsets())) {
            return false;
        }
        if (!Arrays.equals(this.getShapes(), other.getShapes())) {
            return false;
        }
        if (!Arrays.equals(this.getStrides(), other.getStrides())) {
            return false;
        }
        return this.getOffset() == other.getOffset();
    }

    protected boolean canEqual(Object other) {
        return other instanceof ShapeOffsetResolution;
    }

    /** Hash consistent with {@link #equals(Object)}, built from the same fields. */
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        INDArray $arr = this.getArr();
        result = result * PRIME + ($arr == null ? 43 : $arr.hashCode());
        result = result * PRIME + Arrays.hashCode(this.getFixed());
        result = result * PRIME + Arrays.hashCode(this.getPrependAxis());
        result = result * PRIME + Arrays.hashCode(this.getOffsets());
        result = result * PRIME + Arrays.hashCode(this.getShapes());
        result = result * PRIME + Arrays.hashCode(this.getStrides());
        long $offset = this.getOffset();
        result = result * PRIME + (int)($offset >>> 32 ^ $offset);
        return result;
    }

    public String toString() {
        return "ShapeOffsetResolution(arr=" + this.getArr() + ", fixed=" + Arrays.toString(this.getFixed()) + ", prependAxis=" + Arrays.toString(this.getPrependAxis()) + ", offsets=" + Arrays.toString(this.getOffsets()) + ", shapes=" + Arrays.toString(this.getShapes()) + ", strides=" + Arrays.toString(this.getStrides()) + ", offset=" + this.getOffset() + ")";
    }
}

