/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.blas.impl;

import org.nd4j.linalg.api.blas.Level3;
import org.nd4j.linalg.api.blas.impl.BaseLevel;
import org.nd4j.linalg.api.blas.params.GemmParams;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base implementation of the BLAS Level 3 (matrix-matrix) routines.
 *
 * <p>Every public entry point follows the same pattern:
 * <ol>
 *   <li>records the call with {@link OpProfiler} when the executioner's profiling mode is
 *       {@code ALL};</li>
 *   <li>validates the operands against the expected data type via
 *       {@link DefaultOpExecutioner#validateDataType};</li>
 *   <li>dispatches on {@code A}'s data type to a precision-specific abstract kernel
 *       ({@code d*} for DOUBLE, {@code s*} for FLOAT, {@code h*} for HALF where available);</li>
 *   <li>runs {@link OpExecutionerUtil#checkForAny} on the array the kernel wrote.</li>
 * </ol>
 *
 * <p>Concrete backends implement the kernels.
 */
public abstract class BaseLevel3 extends BaseLevel implements Level3 {
    private static final Logger log = LoggerFactory.getLogger(BaseLevel3.class);

    /**
     * General matrix multiply: {@code C := alpha * op(A) * op(B) + beta * C}.
     *
     * <p>{@link GemmParams} derives the transposition flags, dimensions and leading
     * dimensions from the operands' layouts. The {@code TransA}/{@code TransB} arguments are
     * not consulted. DOUBLE data is dispatched to {@link #dgemm}, FLOAT data to
     * {@link #sgemm}, and any other type to {@link #hgemm}.
     *
     * @param Order  storage order passed through to the kernel
     * @param TransA not used; see above
     * @param TransB not used; see above
     * @param alpha  scale applied to {@code op(A) * op(B)}
     * @param A      left operand
     * @param B      right operand
     * @param beta   scale applied to the existing contents of {@code C}
     * @param C      output matrix handed to the kernel
     */
    @Override
    public void gemm(char Order, char TransA, char TransB, double alpha, INDArray A, INDArray B, double beta, INDArray C) {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
            OpProfiler.getInstance().processBlasCall(true, A, B, C);
        }
        GemmParams params = new GemmParams(A, B, C);
        DataBuffer.Type dataType = A.data().dataType();
        // The caller's alpha/beta are forwarded. This path used to substitute 1/0, which
        // silently ignored both arguments.
        if (dataType == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, params.getA(), params.getB(), params.getC());
            this.dgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), beta, C, params.getLdc());
        } else if (dataType == DataBuffer.Type.FLOAT) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, params.getA(), params.getB(), params.getC());
            this.sgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta, C, params.getLdc());
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, params.getA(), params.getB(), params.getC());
            this.hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta, C, params.getLdc());
        }
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * General matrix multiply with explicit transpose flags:
     * {@code C := alpha * op(A) * op(B) + beta * C}.
     *
     * <p>{@code A.ordering()} is used as the storage order. DOUBLE data is dispatched to
     * {@link #dgemm}, FLOAT data to {@link #sgemm}, and any other type to {@link #hgemm}.
     *
     * @param A          left operand
     * @param B          right operand
     * @param C          output matrix handed to the kernel
     * @param transposeA whether {@code op(A)} is {@code A}'s transpose (forwarded to {@link GemmParams})
     * @param transposeB whether {@code op(B)} is {@code B}'s transpose (forwarded to {@link GemmParams})
     * @param alpha      scale applied to {@code op(A) * op(B)}
     * @param beta       scale applied to the existing contents of {@code C}
     */
    @Override
    public void gemm(INDArray A, INDArray B, INDArray C, boolean transposeA, boolean transposeB, double alpha, double beta) {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
            OpProfiler.getInstance().processBlasCall(true, A, B, C);
        }
        GemmParams params = new GemmParams(A, B, C, transposeA, transposeB);
        DataBuffer.Type dataType = A.data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, params.getA(), params.getB(), C);
            this.dgemm(A.ordering(), params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), beta, C, params.getLdc());
        } else if (dataType == DataBuffer.Type.FLOAT) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, params.getA(), params.getB(), C);
            this.sgemm(A.ordering(), params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta, C, params.getLdc());
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, params.getA(), params.getB(), C);
            this.hgemm(A.ordering(), params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta, C, params.getLdc());
        }
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Symmetric matrix multiply:
     * {@code C := alpha * A * B + beta * C} (Side 'L') or
     * {@code C := alpha * B * A + beta * C} (Side 'R'), where {@code A} is symmetric.
     *
     * <p>{@code M x N} is taken from {@code C}. DOUBLE data is dispatched to {@link #dsymm};
     * every other type goes to {@link #ssymm} and is validated as FLOAT.
     *
     * @param Order storage order passed through to the kernel
     * @param Side  which side {@code A} multiplies from
     * @param Uplo  which triangle of {@code A} is referenced
     * @param alpha scale applied to the product
     * @param A     symmetric operand
     * @param B     general operand
     * @param beta  scale applied to the existing contents of {@code C}
     * @param C     output matrix handed to the kernel
     */
    @Override
    public void symm(char Order, char Side, char Uplo, double alpha, INDArray A, INDArray B, double beta, INDArray C) {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
            OpProfiler.getInstance().processBlasCall(false, A, B, C);
        }
        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, B, C);
            this.dsymm(Order, Side, Uplo, C.rows(), C.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0), beta, C, (int) C.size(0));
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, B, C);
            this.ssymm(Order, Side, Uplo, C.rows(), C.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0), (float) beta, C, (int) C.size(0));
        }
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Symmetric rank-k update:
     * {@code C := alpha * A * A^T + beta * C} (Trans 'N') or
     * {@code C := alpha * A^T * A + beta * C} (otherwise).
     *
     * <p>{@code N} is the order of {@code C}. {@code K} is {@code A}'s column count when
     * {@code Trans} is 'N'/'n' and its row count otherwise. DOUBLE data is dispatched to
     * {@link #dsyrk}; every other type goes to {@link #ssyrk} and is validated as FLOAT.
     *
     * @param Order storage order passed through to the kernel
     * @param Uplo  which triangle of {@code C} is updated
     * @param Trans whether {@code A} is used as-is ('N') or transposed
     * @param alpha scale applied to the rank-k product
     * @param A     input matrix ({@code N x K} for 'N', {@code K x N} otherwise)
     * @param beta  scale applied to the existing contents of {@code C}
     * @param C     symmetric output matrix handed to the kernel
     */
    @Override
    public void syrk(char Order, char Uplo, char Trans, double alpha, INDArray A, double beta, INDArray C) {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
            OpProfiler.getInstance().processBlasCall(false, A, C);
        }
        int n = C.rows();
        // The inner dimension must come from A's shape. It was previously hard-coded to 1,
        // which discarded every column/row of A but the first.
        int k = isNoTranspose(Trans) ? A.columns() : A.rows();
        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, C);
            this.dsyrk(Order, Uplo, Trans, n, k, alpha, A, (int) A.size(0), beta, C, (int) C.size(0));
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, C);
            this.ssyrk(Order, Uplo, Trans, n, k, (float) alpha, A, (int) A.size(0), (float) beta, C, (int) C.size(0));
        }
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Symmetric rank-2k update:
     * {@code C := alpha * A * B^T + alpha * B * A^T + beta * C} (Trans 'N') or
     * {@code C := alpha * A^T * B + alpha * B^T * A + beta * C} (otherwise).
     *
     * <p>{@code N} is the order of {@code C}. {@code K} is {@code A}'s column count when
     * {@code Trans} is 'N'/'n' and its row count otherwise. DOUBLE data is dispatched to
     * {@link #dsyr2k}; every other type goes to {@link #ssyr2k} and is validated as FLOAT.
     *
     * @param Order storage order passed through to the kernel
     * @param Uplo  which triangle of {@code C} is updated
     * @param Trans whether {@code A}/{@code B} are used as-is ('N') or transposed
     * @param alpha scale applied to the rank-2k product
     * @param A     first input matrix ({@code N x K} for 'N', {@code K x N} otherwise)
     * @param B     second input matrix, same shape as {@code A}
     * @param beta  scale applied to the existing contents of {@code C}
     * @param C     symmetric output matrix handed to the kernel
     */
    @Override
    public void syr2k(char Order, char Uplo, char Trans, double alpha, INDArray A, INDArray B, double beta, INDArray C) {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
            OpProfiler.getInstance().processBlasCall(false, A, B, C);
        }
        // For a transposed op, A is K x N, so A.rows()/A.columns() cannot be used directly as N/K.
        int n = C.rows();
        int k = isNoTranspose(Trans) ? A.columns() : A.rows();
        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, B, C);
            this.dsyr2k(Order, Uplo, Trans, n, k, alpha, A, (int) A.size(0), B, (int) B.size(0), beta, C, (int) C.size(0));
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, B, C);
            this.ssyr2k(Order, Uplo, Trans, n, k, (float) alpha, A, (int) A.size(0), B, (int) B.size(0), (float) beta, C, (int) C.size(0));
        }
        OpExecutionerUtil.checkForAny(C);
    }

    /**
     * Triangular matrix multiply:
     * {@code B := alpha * op(A) * B} (Side 'L') or {@code B := alpha * B * op(A)} (Side 'R').
     *
     * <p>The kernel receives only {@code A} and {@code B}, so the result lands in {@code B}.
     * {@code C} is only type-validated; it is never passed to the kernel. {@code M x N} is
     * taken from {@code B}. DOUBLE data is dispatched to {@link #dtrmm}; every other type
     * goes to {@link #strmm} and is validated as FLOAT.
     *
     * @param Order  storage order passed through to the kernel
     * @param Side   which side {@code op(A)} multiplies from
     * @param Uplo   whether {@code A} is upper or lower triangular
     * @param TransA form of {@code op(A)}
     * @param Diag2  whether {@code A} is unit-diagonal
     * @param alpha  scale applied to the product
     * @param A      triangular operand
     * @param B      general operand, overwritten by the kernel
     * @param C      validated for data type only
     */
    @Override
    public void trmm(char Order, char Side, char Uplo, char TransA, char Diag2, double alpha, INDArray A, INDArray B, INDArray C) {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
            OpProfiler.getInstance().processBlasCall(false, A, B, C);
        }
        // BLAS M/N describe B. A is square, so using A's shape was only right when B matched it.
        int m = B.rows();
        int n = B.columns();
        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, B, C);
            this.dtrmm(Order, Side, Uplo, TransA, Diag2, m, n, alpha, A, (int) A.size(0), B, (int) B.size(0));
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, B, C);
            this.strmm(Order, Side, Uplo, TransA, Diag2, m, n, (float) alpha, A, (int) A.size(0), B, (int) B.size(0));
        }
        // Check the array the kernel actually wrote (C is never handed to it).
        OpExecutionerUtil.checkForAny(B);
    }

    /**
     * Triangular solve:
     * {@code op(A) * X = alpha * B} (Side 'L') or {@code X * op(A) = alpha * B} (Side 'R').
     * The solution {@code X} overwrites {@code B}.
     *
     * <p>{@code M x N} is taken from {@code B}. DOUBLE data is dispatched to {@link #dtrsm};
     * every other type goes to {@link #strsm} and is validated as FLOAT.
     *
     * @param Order  storage order passed through to the kernel
     * @param Side   which side {@code op(A)} appears on
     * @param Uplo   whether {@code A} is upper or lower triangular
     * @param TransA form of {@code op(A)}
     * @param Diag2  whether {@code A} is unit-diagonal
     * @param alpha  scale applied to the right-hand side
     * @param A      triangular coefficient matrix
     * @param B      right-hand side, overwritten with the solution
     */
    @Override
    public void trsm(char Order, char Side, char Uplo, char TransA, char Diag2, double alpha, INDArray A, INDArray B) {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
            OpProfiler.getInstance().processBlasCall(false, A, B);
        }
        // BLAS M/N describe B. A is square, so using A's shape was only right when B matched it.
        int m = B.rows();
        int n = B.columns();
        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, B);
            this.dtrsm(Order, Side, Uplo, TransA, Diag2, m, n, alpha, A, (int) A.size(0), B, (int) B.size(0));
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, B);
            this.strsm(Order, Side, Uplo, TransA, Diag2, m, n, (float) alpha, A, (int) A.size(0), B, (int) B.size(0));
        }
        OpExecutionerUtil.checkForAny(B);
    }

    /** Returns true when the BLAS transpose flag means "no transpose" ('N' or 'n'). */
    private static boolean isNoTranspose(char trans) {
        return trans == 'N' || trans == 'n';
    }

    /** Half-precision GEMM kernel; same argument layout as BLAS {@code ?gemm}. */
    protected abstract void hgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc);

    /** Single-precision GEMM kernel; same argument layout as BLAS {@code sgemm}. */
    protected abstract void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc);

    /** Single-precision SYMM kernel; same argument layout as BLAS {@code ssymm}. */
    protected abstract void ssymm(char Order, char Side, char Uplo, int M, int N, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc);

    /** Single-precision SYRK kernel; same argument layout as BLAS {@code ssyrk}. */
    protected abstract void ssyrk(char Order, char Uplo, char Trans, int N, int K, float alpha, INDArray A, int lda, float beta, INDArray C, int ldc);

    /** Single-precision SYR2K kernel; same argument layout as BLAS {@code ssyr2k}. */
    protected abstract void ssyr2k(char Order, char Uplo, char Trans, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc);

    /** Single-precision TRMM kernel; same argument layout as BLAS {@code strmm}. */
    protected abstract void strmm(char Order, char Side, char Uplo, char TransA, char Diag, int M, int N, float alpha, INDArray A, int lda, INDArray B, int ldb);

    /** Single-precision TRSM kernel; same argument layout as BLAS {@code strsm}. */
    protected abstract void strsm(char Order, char Side, char Uplo, char TransA, char Diag, int M, int N, float alpha, INDArray A, int lda, INDArray B, int ldb);

    /** Double-precision GEMM kernel; same argument layout as BLAS {@code dgemm}. */
    protected abstract void dgemm(char Order, char TransA, char TransB, int M, int N, int K, double alpha, INDArray A, int lda, INDArray B, int ldb, double beta, INDArray C, int ldc);

    /** Double-precision SYMM kernel; same argument layout as BLAS {@code dsymm}. */
    protected abstract void dsymm(char Order, char Side, char Uplo, int M, int N, double alpha, INDArray A, int lda, INDArray B, int ldb, double beta, INDArray C, int ldc);

    /** Double-precision SYRK kernel; same argument layout as BLAS {@code dsyrk}. */
    protected abstract void dsyrk(char Order, char Uplo, char Trans, int N, int K, double alpha, INDArray A, int lda, double beta, INDArray C, int ldc);

    /** Double-precision SYR2K kernel; same argument layout as BLAS {@code dsyr2k}. */
    protected abstract void dsyr2k(char Order, char Uplo, char Trans, int N, int K, double alpha, INDArray A, int lda, INDArray B, int ldb, double beta, INDArray C, int ldc);

    /** Double-precision TRMM kernel; same argument layout as BLAS {@code dtrmm}. */
    protected abstract void dtrmm(char Order, char Side, char Uplo, char TransA, char Diag, int M, int N, double alpha, INDArray A, int lda, INDArray B, int ldb);

    /** Double-precision TRSM kernel; same argument layout as BLAS {@code dtrsm}. */
    protected abstract void dtrsm(char Order, char Side, char Uplo, char TransA, char Diag, int M, int N, double alpha, INDArray A, int lda, INDArray B, int ldb);
}

