/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.arrow;

import com.google.flatbuffers.FlatBufferBuilder;
import java.nio.ByteBuffer;
import org.apache.arrow.flatbuf.Buffer;
import org.apache.arrow.flatbuf.Tensor;
import org.apache.arrow.flatbuf.TensorDim;
import org.nd4j.arrow.DataBufferStruct;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Converts ND4J {@link INDArray}s to and from Apache Arrow {@link Tensor} flatbuffers.
 *
 * <p>The element type is written to the tensor's type tag: {@code INT} and {@code LONG} arrays use
 * tag 2 (Arrow {@code Int}), {@code FLOAT} arrays use tag 3 (Arrow {@code FloatingPoint}) and
 * {@code DOUBLE} arrays use tag 7 (Arrow {@code Decimal}). On read, integer width is recovered from
 * the serialized byte length divided by the element count. Strides are written in bytes and
 * converted back to element strides on read.
 */
public class ArrowSerde {

    /** Type tag written for INT/LONG data (Arrow {@code Type.Int}). */
    private static final byte ARROW_TYPE_INT = 2;
    /** Type tag written for FLOAT data (Arrow {@code Type.FloatingPoint}). */
    private static final byte ARROW_TYPE_FLOATING_POINT = 3;
    /** Type tag this class writes for DOUBLE data (Arrow {@code Type.Decimal}). */
    private static final byte ARROW_TYPE_DECIMAL = 7;

    /**
     * Reconstructs an {@link INDArray} from an Arrow {@link Tensor}.
     *
     * <p>The element size is derived as {@code data().length() / prod(shape)}. It is used both to
     * select the ND4J data type and to convert the tensor's byte strides into element strides. When
     * the tensor carries no strides, row-major (C-order) strides are assumed.
     *
     * @param tensor the tensor to read
     * @return an array whose buffer is created from the tensor's backing byte buffer, with the
     *     tensor's shape and (element) strides applied
     * @throws ND4JIllegalStateException if the tensor has no data buffer, has no elements, has a data
     *     buffer too small for its shape, or has a stride count that differs from its rank
     * @throws IllegalArgumentException if the tensor's type tag or element size is not supported
     */
    public static INDArray fromTensor(Tensor tensor) {
        int rank = tensor.shapeLength();
        int[] shape = new int[rank];
        for (int i = 0; i < rank; i++) {
            shape[i] = (int) tensor.shape(i).size();
        }

        Buffer buffer = tensor.data();
        if (buffer == null) {
            throw new ND4JIllegalStateException("Buffer was not serialized properly.");
        }

        int length = ArrayUtil.prod(shape);
        if (length <= 0) {
            // With no elements the per-element byte size cannot be derived from the buffer length.
            throw new ND4JIllegalStateException(
                    "Unable to deserialize a tensor with " + length + " elements.");
        }
        // Divide in long arithmetic so buffers larger than Integer.MAX_VALUE bytes do not wrap.
        int elementSize = (int) (buffer.length() / length);
        if (elementSize <= 0) {
            throw new ND4JIllegalStateException("Serialized buffer of " + buffer.length()
                    + " bytes is too small for " + length + " elements.");
        }

        int[] stride = elementStrides(tensor, shape, elementSize);
        DataBuffer.Type type = typeFromTensorType(tensor.typeType(), elementSize);
        DataBuffer dataBuffer = DataBufferStruct.createFromByteBuffer(
                tensor.getByteBuffer(), (int) buffer.offset(), type, length);
        INDArray arr = Nd4j.create(dataBuffer, shape);
        arr.setShapeAndStride(shape, stride);
        return arr;
    }

    /**
     * Returns the tensor's strides expressed in elements rather than bytes. If the tensor omits
     * strides, this derives row-major (C-order) element strides from {@code shape}.
     *
     * @throws ND4JIllegalStateException if the tensor has a non-zero stride count that differs from
     *     its rank
     */
    private static int[] elementStrides(Tensor tensor, int[] shape, int elementSize) {
        int rank = shape.length;
        int stridesLength = tensor.stridesLength();
        int[] stride = new int[rank];
        if (stridesLength == 0) {
            // Arrow treats omitted strides as row-major: the last dimension is contiguous.
            int running = 1;
            for (int i = rank - 1; i >= 0; i--) {
                stride[i] = running;
                running *= shape[i];
            }
            return stride;
        }
        if (stridesLength != rank) {
            throw new ND4JIllegalStateException(
                    "Tensor has " + stridesLength + " strides but rank " + rank + ".");
        }
        for (int i = 0; i < rank; i++) {
            stride[i] = (int) (tensor.strides(i) / elementSize);
        }
        return stride;
    }

    /**
     * Serializes an {@link INDArray} into a freshly built Arrow {@link Tensor}.
     *
     * <p>Views are copied with {@code dup()} first. The shape, byte strides, type tag and data are
     * then all taken from that single array, so the recorded strides describe the buffer that is
     * actually written.
     *
     * @param arr the array to serialize
     * @return a tensor rooted in a new flatbuffer
     * @throws IllegalArgumentException if the array's data type has no Arrow type mapping here
     */
    public static Tensor toTensor(INDArray arr) {
        // A view's strides refer to its parent's buffer, while only a compact dup() of it is
        // serialized. Derive every field from the same array so strides and data agree.
        INDArray toSerialize = arr.isView() ? arr.dup() : arr;

        FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(1024);
        long[] strides = getArrowStrides(toSerialize);
        // Vectors must be written before the Tensor table is started.
        int shapeOffset = createDims(bufferBuilder, toSerialize);
        int stridesOffset = Tensor.createStridesVector(bufferBuilder, strides);

        Tensor.startTensor(bufferBuilder);
        addTypeTypeRelativeToNDArray(bufferBuilder, toSerialize);
        Tensor.addShape(bufferBuilder, shapeOffset);
        Tensor.addStrides(bufferBuilder, stridesOffset);
        Tensor.addData(bufferBuilder, addDataForArr(bufferBuilder, toSerialize));
        int endTensor = Tensor.endTensor(bufferBuilder);

        Tensor.finishTensorBuffer(bufferBuilder, endTensor);
        return Tensor.getRootAsTensor(bufferBuilder.dataBuffer());
    }

    /**
     * Writes the array's data into the builder and creates the Arrow {@link Buffer} struct for it.
     * Views are {@code dup()}ed first so that only the array's own elements are written.
     *
     * @param bufferBuilder the builder being populated
     * @param arr the array whose data is written
     * @return the offset returned by {@link Buffer#createBuffer}
     */
    public static int addDataForArr(FlatBufferBuilder bufferBuilder, INDArray arr) {
        DataBuffer toAdd = arr.isView() ? arr.dup().data() : arr.data();
        int offset = DataBufferStruct.createDataBufferStruct(bufferBuilder, toAdd);
        long byteLength = toAdd.length() * (long) toAdd.getElementSize();
        return Buffer.createBuffer(bufferBuilder, (long) offset, byteLength);
    }

    /**
     * Adds the tensor type tag for the array's data type to the Tensor table being built.
     * INT and LONG both map to tag 2, FLOAT to 3 and DOUBLE to 7.
     *
     * @param bufferBuilder a builder with a Tensor table currently started
     * @param arr the array whose data type is inspected
     * @throws IllegalArgumentException if the data type is not INT, LONG, FLOAT or DOUBLE
     */
    public static void addTypeTypeRelativeToNDArray(FlatBufferBuilder bufferBuilder, INDArray arr) {
        DataBuffer.Type dataType = arr.data().dataType();
        switch (dataType) {
            case LONG:
            case INT:
                Tensor.addTypeType(bufferBuilder, ARROW_TYPE_INT);
                break;
            case FLOAT:
                Tensor.addTypeType(bufferBuilder, ARROW_TYPE_FLOATING_POINT);
                break;
            case DOUBLE:
                Tensor.addTypeType(bufferBuilder, ARROW_TYPE_DECIMAL);
                break;
            default:
                // Without a tag the tensor could not be read back by typeFromTensorType.
                throw new IllegalArgumentException(
                        "Unsupported data type for Arrow serialization: " + dataType);
        }
    }

    /**
     * Writes one {@link TensorDim} (with an empty name) per dimension and returns the shape vector.
     *
     * @param bufferBuilder the builder being populated (must not be inside a table)
     * @param arr the array whose sizes are recorded
     * @return the offset of the shape vector
     */
    public static int createDims(FlatBufferBuilder bufferBuilder, INDArray arr) {
        int rank = arr.rank();
        int[] tensorDimOffsets = new int[rank];
        for (int i = 0; i < rank; i++) {
            int nameOffset = bufferBuilder.createString("");
            tensorDimOffsets[i] = TensorDim.createTensorDim(bufferBuilder, arr.size(i), nameOffset);
        }
        return Tensor.createShapeVector(bufferBuilder, tensorDimOffsets);
    }

    /**
     * Returns the array's strides converted from elements to bytes.
     *
     * @param arr the array to inspect
     * @return one byte stride per dimension
     */
    public static long[] getArrowStrides(INDArray arr) {
        long[] ret = new long[arr.rank()];
        // Multiply in long arithmetic so large strides do not overflow int.
        long elementSize = arr.data().getElementSize();
        for (int i = 0; i < ret.length; i++) {
            ret[i] = arr.stride(i) * elementSize;
        }
        return ret;
    }

    /**
     * Maps a tensor type tag plus element size (in bytes) back to an ND4J data type.
     *
     * <ul>
     *   <li>3 (FloatingPoint): 8-byte elements map to DOUBLE, any other size to FLOAT.</li>
     *   <li>7 (Decimal): DOUBLE.</li>
     *   <li>2 (Int): 4-byte elements map to INT, 8-byte elements to LONG.</li>
     * </ul>
     *
     * @param type the tensor's type tag
     * @param elementSize bytes per element
     * @return the matching ND4J data type
     * @throws IllegalArgumentException if the tag is unsupported, or if an Int tag has an element size
     *     other than 4 or 8
     */
    public static DataBuffer.Type typeFromTensorType(byte type, int elementSize) {
        switch (type) {
            case ARROW_TYPE_FLOATING_POINT:
                return elementSize == 8 ? DataBuffer.Type.DOUBLE : DataBuffer.Type.FLOAT;
            case ARROW_TYPE_DECIMAL:
                return DataBuffer.Type.DOUBLE;
            case ARROW_TYPE_INT:
                if (elementSize == 4) {
                    return DataBuffer.Type.INT;
                }
                if (elementSize == 8) {
                    return DataBuffer.Type.LONG;
                }
                throw new IllegalArgumentException("Unable to determine data type");
            default:
                throw new IllegalArgumentException(
                        "Only valid types are Type.Int, Type.FloatingPoint and Type.Decimal, but got type "
                                + type);
        }
    }
}

