package org.dromara.easyai.unet;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.dromara.easyai.conv.ConvCount;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveEntity.ConvParameter;
import org.dromara.easyai.nerveEntity.ConvSize;

/* loaded from: input_file:org/dromara/easyai/unet/UNetDecoder.class */
/**
 * One decoder stage of the U-Net implementation.
 *
 * <p>Decoders form a doubly linked chain ({@link #setBeforeDecoder} / {@link #setAfterDecoder}).
 * On the forward pass ({@link #sendFeature}) a stage optionally averages the feature maps of its
 * same-level encoder into the incoming features, runs {@code upConvAndPooling} and hands the result
 * to the next stage. The stage flagged as the last layer merges all channels with {@code oneConv}
 * and then either back-propagates the error against the supplied target image (training) or emits
 * the result image (inference).
 *
 * <p>On the backward pass ({@link #backErrorMatrix}) the error travels through the chain in the
 * opposite direction and finally reaches the encoder registered with {@link #setEncoder}.
 */
public class UNetDecoder extends ConvCount {
    /** Kernel size; every weight matrix created here holds {@code kerSize * kerSize} values. */
    private final int kerSize;
    /** Depth index of this stage; encoder features are merged in only when it is greater than 1. */
    private final int deep;
    /** Rate forwarded to the inherited back-propagation routines (presumably the learning rate). */
    private final float studyRate;
    /** Number of channels handled by this stage. */
    private final int channelNo;
    /** {@code true} when this stage produces the final output of the network. */
    private final boolean lastLay;
    private final ActiveFunction activeFunction;
    /** Stage that receives this stage's forward output. */
    private UNetDecoder afterDecoder;
    /** Stage that receives this stage's backward error. */
    private UNetDecoder beforeDecoder;
    /** Encoder that receives the backward error when there is no {@link #beforeDecoder}. */
    private UNetEncoder encoder;
    /** Same-level encoder whose feature maps are averaged into the forward input. */
    private UNetEncoder myUNetEncoder;
    /** Optional post-processor for the inference result; may be {@code null}. */
    private final Cutting cutting;
    /** Rate passed to {@code backOneConv} for the channel-merge weights of the last layer. */
    private final float oneConvStudyRate;
    private final ConvParameter convParameter = new ConvParameter();
    private final MatrixOperation matrixOperation = new MatrixOperation();
    /** Remembers the size of the encoder feature maps seen during the last training forward pass. */
    private final ConvSize convSize = new ConvSize();

    /**
     * Creates a decoder stage and randomly initialises its weights.
     *
     * @param kerSize          kernel size
     * @param deep             depth index of this stage
     * @param channelNo        number of channels; one weight matrix pair is created per channel
     * @param activeFunction   activation function forwarded to the inherited convolution routines
     * @param lastLay          whether this stage is the output layer; if so the channel-merge
     *                         weights are initialised as well
     * @param studyRate        rate used when back-propagating through the convolutions
     * @param cutting          optional post-processor for inference output, may be {@code null}
     * @param oneConvStudyRate rate used when updating the channel-merge weights
     * @throws Exception if a matrix operation fails
     */
    public UNetDecoder(int kerSize, int deep, int channelNo, ActiveFunction activeFunction, boolean lastLay,
                       float studyRate, Cutting cutting, float oneConvStudyRate) throws Exception {
        this.cutting = cutting;
        this.kerSize = kerSize;
        this.oneConvStudyRate = oneConvStudyRate;
        this.deep = deep;
        this.studyRate = studyRate;
        this.lastLay = lastLay;
        this.channelNo = channelNo;
        this.activeFunction = activeFunction;

        Random random = new Random();
        List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
        List<Matrix> upNerveMatrixList = convParameter.getUpNerveMatrixList();
        List<ConvSize> convSizeList = convParameter.getConvSizeList();
        for (int channel = 0; channel < channelNo; channel++) {
            upNerveMatrixList.add(initUpNervePowerMatrix(random));
            initNervePowerMatrix(random, nerveMatrixList);
            convSizeList.add(new ConvSize());
        }
        if (lastLay) {
            // One merge weight per channel, scaled down by the channel count.
            List<Float> oneConvPower = new ArrayList<>();
            for (int channel = 0; channel < channelNo; channel++) {
                oneConvPower.add(random.nextFloat() / channelNo);
            }
            convParameter.setUpOneConvPower(oneConvPower);
        }
    }

    /**
     * @return the parameter holder containing this stage's weights
     */
    public ConvParameter getConvParameter() {
        return convParameter;
    }

    /**
     * Adapts an image to the requested size along X by cropping or padding.
     *
     * @param source  image to adapt
     * @param targetX requested X size
     * @param targetY requested Y size
     * @return the cropped image when {@code source} is larger along X, a padded copy when it is
     *         smaller, or {@code null} when the X sizes already match
     * @throws Exception if the underlying cut/fill operation fails
     */
    private ThreeChannelMatrix fillColor(ThreeChannelMatrix source, int targetX, int targetY) throws Exception {
        int difference = source.getX() - targetX;
        int offset = difference / 2;
        if (offset == 0) {
            // A one-row difference would otherwise round down to no offset at all.
            offset = 1;
        }
        if (difference > 0) {
            return source.cutChannel(offset, 0, targetX, targetY);
        }
        if (difference < 0) {
            ThreeChannelMatrix padded = getFaceMatrix(targetX, targetY);
            padded.fill(Math.abs(offset), 0, source);
            return padded;
        }
        return null;
    }

    /**
     * Builds a {@link ThreeChannelMatrix} of the given size whose R, G, B and H matrices are freshly
     * constructed (presumably zero-filled — confirm against {@code Matrix}).
     */
    private ThreeChannelMatrix getFaceMatrix(int x, int y) {
        ThreeChannelMatrix face = new ThreeChannelMatrix();
        face.setX(x);
        face.setY(y);
        face.setMatrixR(new Matrix(x, y));
        face.setMatrixG(new Matrix(x, y));
        face.setMatrixB(new Matrix(x, y));
        face.setH(new Matrix(x, y));
        return face;
    }

    /**
     * Averages each encoder feature map into the decoder feature map at the same index.
     *
     * @param encoderFeatures feature maps coming from the same-level encoder
     * @param features        decoder feature maps, modified in place
     * @param training        whether the encoder feature size should be remembered for the backward pass
     */
    private void addFeatures(List<Matrix> encoderFeatures, List<Matrix> features, boolean training) throws Exception {
        int count = encoderFeatures.size();
        for (int index = 0; index < count; index++) {
            addFeature(encoderFeatures.get(index), features.get(index), training);
        }
    }

    /**
     * Replaces every cell of {@code feature} with the mean of itself and the matching cell of
     * {@code encoderFeature}; cells that {@code encoderFeature} does not cover count as zero.
     */
    private void addFeature(Matrix encoderFeature, Matrix feature, boolean training) throws Exception {
        int encoderX = encoderFeature.getX();
        int encoderY = encoderFeature.getY();
        if (training) {
            // Needed later by sendEncoderError to shape the error sent back to the encoder.
            convSize.setXInput(encoderX);
            convSize.setYInput(encoderY);
        }
        int featureX = feature.getX();
        int featureY = feature.getY();
        for (int row = 0; row < featureX; row++) {
            for (int col = 0; col < featureY; col++) {
                boolean covered = row < encoderX && col < encoderY;
                float encoderValue = covered ? encoderFeature.getNumber(row, col) : 0.0f;
                feature.setNub(row, col, (feature.getNumber(row, col) + encoderValue) / 2.0f);
            }
        }
    }

    /**
     * Final step of the last layer: merges all channels into one map and either trains on it or
     * emits it.
     *
     * <p>Training: the target image is scaled and fitted to the feature size, its average grayscale
     * is compared with the merged map, and the resulting error is back-propagated.
     *
     * <p>Inference: the merged map is centred on a canvas the size of {@code referenceImage}, wrapped
     * in a {@link ThreeChannelMatrix} and passed to {@link Cutting#cut} or, without a cutter, straight
     * to {@code outBack}. The reference image is expected to be at least as large as the merged map.
     *
     * @param featureList    per-channel feature maps produced by this stage
     * @param target         training target image (used only when {@code training} is true)
     * @param training       selects the training or the inference path
     * @param outBack        callback receiving the inference result
     * @param referenceImage image defining the canvas size of the inference result
     */
    private void toThreeChannelMatrix(List<Matrix> featureList, ThreeChannelMatrix target, boolean training,
                                      OutBack outBack, ThreeChannelMatrix referenceImage) throws Exception {
        int featureX = featureList.get(0).getX();
        int featureY = featureList.get(0).getY();
        List<Float> oneConvPower = convParameter.getUpOneConvPower();
        Matrix merged = oneConv(featureList, oneConvPower);

        if (training) {
            ThreeChannelMatrix scaled = target.scale(true, featureY);
            ThreeChannelMatrix fitted = fillColor(scaled, featureX, featureY);
            if (fitted == null) {
                fitted = scaled;
            }
            Matrix error = matrixOperation.sub(fitted.CalculateAvgGrayscale(), merged);
            // Per-channel errors are derived before backOneConv runs, i.e. from the current weights.
            List<Matrix> channelErrors = new ArrayList<>();
            for (int channel = 0; channel < channelNo; channel++) {
                channelErrors.add(matrixOperation.mathMulBySelf(error, oneConvPower.get(channel).floatValue()));
            }
            backOneConv(error, featureList, oneConvPower, oneConvStudyRate, true);
            backLastError(channelErrors);
            return;
        }

        int canvasX = referenceImage.getX();
        int canvasY = referenceImage.getY();
        int mergedX = merged.getX();
        int mergedY = merged.getY();
        int offsetX = (canvasX - mergedX) / 2;
        int offsetY = (canvasY - mergedY) / 2;
        Matrix canvas = new Matrix(canvasX, canvasY);
        // Copy the complete merged map, shifted by the centring offset. Bounding the loop by the
        // map size alone would drop the last offsetX rows / offsetY columns.
        for (int row = 0; row < mergedX; row++) {
            for (int col = 0; col < mergedY; col++) {
                canvas.setNub(row + offsetX, col + offsetY, merged.getNumber(row, col));
            }
        }

        ThreeChannelMatrix output = new ThreeChannelMatrix();
        // Declared size must describe the canvas the channels actually hold.
        output.setX(canvasX);
        output.setY(canvasY);
        // All three colour channels share the same matrix instance.
        output.setMatrixR(canvas);
        output.setMatrixG(canvas);
        output.setMatrixB(canvas);
        if (cutting != null) {
            cutting.cut(referenceImage, output, outBack);
        } else {
            outBack.getBackThreeChannelMatrix(output);
        }
    }

    /**
     * Starts the backward pass at the last layer: back-propagates the per-channel errors through
     * the down convolutions and relays the result upstream.
     */
    private void backLastError(List<Matrix> channelErrors) throws Exception {
        List<Matrix> downConvError = backAllDownConv(convParameter, channelErrors, studyRate, activeFunction,
                channelNo, kerSize);
        dispatchBackError(downConvError);
    }

    /**
     * Relays a back-propagated error to the same-level encoder (if any) and to the previous decoder,
     * or to the registered encoder when this is the first decoder of the chain.
     */
    private void dispatchBackError(List<Matrix> downConvError) throws Exception {
        if (myUNetEncoder != null) {
            sendEncoderError(downConvError);
        }
        if (beforeDecoder != null) {
            beforeDecoder.backErrorMatrix(downConvError);
        } else {
            encoder.backError(downConvError);
        }
    }

    /**
     * Hands half of each error value to the same-level encoder, reshaped to the encoder feature
     * size recorded during the forward pass; cells outside the error matrix are set to zero.
     */
    private void sendEncoderError(List<Matrix> errorList) throws Exception {
        int inputX = convSize.getXInput();
        int inputY = convSize.getYInput();
        List<Matrix> encoderErrors = new ArrayList<>();
        for (Matrix error : errorList) {
            int errorX = error.getX();
            int errorY = error.getY();
            Matrix reshaped = new Matrix(inputX, inputY);
            for (int row = 0; row < inputX; row++) {
                for (int col = 0; col < inputY; col++) {
                    boolean covered = row < errorX && col < errorY;
                    // Halved because the forward pass averaged encoder and decoder features.
                    float value = covered ? error.getNumber(row, col) / 2.0f : 0.0f;
                    reshaped.setNub(row, col, value);
                }
            }
            encoderErrors.add(reshaped);
        }
        myUNetEncoder.setDecodeErrorMatrix(encoderErrors);
    }

    /**
     * Receives the error from the following decoder, back-propagates it through this stage's
     * up-pooling, up-convolution and down-convolution, and relays the result upstream.
     *
     * @param errorMatrixList per-channel error coming from the next decoder
     * @throws Exception if a matrix operation fails
     */
    protected void backErrorMatrix(List<Matrix> errorMatrixList) throws Exception {
        List<Matrix> downConvError = backAllDownConv(
                convParameter,
                backManyUpConv(backManyUpPooling(errorMatrixList), kerSize, convParameter, studyRate, activeFunction),
                studyRate,
                activeFunction,
                channelNo,
                kerSize);
        dispatchBackError(downConvError);
    }

    /**
     * Forward pass of this stage.
     *
     * @param id             identifier forwarded to the same-level encoder's {@code getAfterConvMatrix}
     * @param outBack        callback receiving the final inference result
     * @param target         training target image, used by the last layer when {@code training} is true
     * @param featureList    incoming per-channel feature maps; modified in place when encoder
     *                       features are merged
     * @param training       whether this pass is a training pass
     * @param referenceImage image defining the canvas of the inference result
     * @throws Exception if a matrix operation fails
     */
    public void sendFeature(long id, OutBack outBack, ThreeChannelMatrix target, List<Matrix> featureList,
                            boolean training, ThreeChannelMatrix referenceImage) throws Exception {
        if (deep > 1) {
            addFeatures(myUNetEncoder.getAfterConvMatrix(id), featureList, training);
        }
        List<Matrix> upFeatures = upConvAndPooling(featureList, convParameter, channelNo, activeFunction, kerSize,
                !lastLay);
        if (lastLay) {
            toThreeChannelMatrix(upFeatures, target, training, outBack, referenceImage);
        } else {
            afterDecoder.sendFeature(id, outBack, target, upFeatures, training, referenceImage);
        }
    }

    /**
     * Creates a {@code 1 x kerSize²} row matrix of random weights, each in {@code [0, 1 / kerSize)}.
     */
    private Matrix initUpNervePowerMatrix(Random random) throws Exception {
        int weightCount = kerSize * kerSize;
        Matrix weights = new Matrix(1, weightCount);
        for (int col = 0; col < weightCount; col++) {
            weights.setNub(0, col, random.nextFloat() / kerSize);
        }
        return weights;
    }

    /**
     * Creates a {@code kerSize² x 1} column matrix of random weights, each in
     * {@code [0, 1 / kerSize)}, and appends it to {@code nerveMatrixList}.
     */
    private void initNervePowerMatrix(Random random, List<Matrix> nerveMatrixList) throws Exception {
        int weightCount = kerSize * kerSize;
        Matrix weights = new Matrix(weightCount, 1);
        for (int row = 0; row < weightCount; row++) {
            weights.setNub(row, 0, random.nextFloat() / kerSize);
        }
        nerveMatrixList.add(weights);
    }

    /** @return the decoder that receives this stage's forward output */
    public UNetDecoder getAfterDecoder() {
        return afterDecoder;
    }

    /** @param afterDecoder the decoder that receives this stage's forward output */
    public void setAfterDecoder(UNetDecoder afterDecoder) {
        this.afterDecoder = afterDecoder;
    }

    /** @return the decoder that receives this stage's backward error */
    public UNetDecoder getBeforeDecoder() {
        return beforeDecoder;
    }

    /** @param beforeDecoder the decoder that receives this stage's backward error */
    public void setBeforeDecoder(UNetDecoder beforeDecoder) {
        this.beforeDecoder = beforeDecoder;
    }

    /** @return the encoder that receives the backward error when there is no previous decoder */
    public UNetEncoder getEncoder() {
        return encoder;
    }

    /** @param encoder the encoder that receives the backward error when there is no previous decoder */
    public void setEncoder(UNetEncoder encoder) {
        this.encoder = encoder;
    }

    /** @param myUNetEncoder the same-level encoder whose features are merged into the forward input */
    public void setMyUNetEncoder(UNetEncoder myUNetEncoder) {
        this.myUNetEncoder = myUNetEncoder;
    }
}
