package org.dromara.easyai.unet;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.dromara.easyai.conv.ConvCount;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveEntity.ConvParameter;
import org.dromara.easyai.nerveEntity.ConvSize;

/* loaded from: input_file:org/dromara/easyai/unet/UNetEncoder.class */
/**
 * One encoder (down-sampling) stage of a U-Net.
 *
 * <p>Each stage owns one randomly initialised {@code kerSize * kerSize} kernel per channel,
 * stored in its {@link ConvParameter}. Stages are linked in both directions:
 * <ul>
 *   <li>The forward pass goes to the next stage ({@link #setAfterEncoder}). If this is the last
 *       stage, it goes to the {@link UNetDecoder} ({@link #setDecoder}).</li>
 *   <li>{@link #backError(List)} goes to the previous stage ({@link #setBeforeEncoder}). If there
 *       is no previous stage, it updates the input-mixing weights ({@code oneConvPower}).</li>
 * </ul>
 */
public class UNetEncoder extends ConvCount {
    private final ConvParameter convParameter = new ConvParameter();
    private final MatrixOperation matrixOperation = new MatrixOperation();
    private final int kerSize;
    private final float studyRate;
    // Stage depth given at construction; within this class it is only consulted in the constructor.
    private final int deep;
    private final int channelNo;
    // Error set via setDecodeErrorMatrix; added to the un-pooled error in backError.
    private List<Matrix> decodeErrorMatrix;
    private final ActiveFunction activeFunction;
    private UNetEncoder afterEncoder;
    private UNetEncoder beforeEncoder;
    private UNetDecoder decoder;
    // Required getX()/getY() of images passed to sendThreeChannel.
    private final int xSize;
    private final int ySize;
    // Learning rate for the input-mixing weights, applied in backError when there is no previous stage.
    private final float oneStudyRate;

    /**
     * Creates an encoder stage with randomly initialised weights.
     *
     * @param kerSize        side length of each convolution kernel
     * @param channelNo      number of channels; one kernel (and one ConvSize) is created per channel
     * @param deep           depth of this stage; when {@code 1}, the stage also creates
     *                       {@code channelNo} x 3 input-mixing weights ({@code oneConvPower})
     * @param activeFunction activation function used by the convolution forward/backward helpers
     * @param studyRate      learning rate for this stage's kernels
     * @param xSize          required {@code getX()} of images given to {@link #sendThreeChannel}
     * @param ySize          required {@code getY()} of images given to {@link #sendThreeChannel}
     * @param oneStudyRate   learning rate for the input-mixing weights
     * @throws Exception if kernel matrix construction fails
     */
    public UNetEncoder(int kerSize, int channelNo, int deep, ActiveFunction activeFunction,
                       float studyRate, int xSize, int ySize, float oneStudyRate) throws Exception {
        Random random = new Random();
        // Fields must be assigned before initNervePowerMatrix, which reads kerSize.
        this.xSize = xSize;
        this.ySize = ySize;
        this.oneStudyRate = oneStudyRate;
        this.studyRate = studyRate;
        this.kerSize = kerSize;
        this.activeFunction = activeFunction;
        this.deep = deep;
        this.channelNo = channelNo;
        List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
        List<ConvSize> convSizeList = convParameter.getConvSizeList();
        for (int channel = 0; channel < channelNo; channel++) {
            initNervePowerMatrix(random, nerveMatrixList);
            convSizeList.add(new ConvSize());
        }
        if (deep == 1) {
            // NOTE(review): depth 1 is presumably the first stage, i.e. the one fed the raw
            // R/G/B channels by sendThreeChannel. Its manyOneConv needs 3 weights per channel.
            List<List<Float>> oneConvPower = new ArrayList<>();
            for (int channel = 0; channel < channelNo; channel++) {
                List<Float> weights = new ArrayList<>(3);
                for (int color = 0; color < 3; color++) {
                    weights.add(random.nextFloat() / 3);
                }
                oneConvPower.add(weights);
            }
            convParameter.setOneConvPower(oneConvPower);
        }
    }

    /** Returns the parameter holder (kernels, sizes, feature maps) owned by this stage. */
    public ConvParameter getConvParameter() {
        return this.convParameter;
    }

    /**
     * Sets the error matrices that {@link #backError(List)} adds to this stage's un-pooled error.
     * NOTE(review): presumably supplied by the decoder through the skip connection.
     */
    public void setDecodeErrorMatrix(List<Matrix> list) {
        this.decodeErrorMatrix = list;
    }

    /**
     * Removes and returns the feature maps stored under {@code eventId} in the parameter's
     * feature map.
     *
     * @return the stored matrices, or {@code null} if nothing is stored under that key
     */
    public List<Matrix> getAfterConvMatrix(long eventId) {
        // Map.remove already returns the previous value, so no separate get() is needed.
        return convParameter.getFeatureMap().remove(Long.valueOf(eventId));
    }

    /**
     * Runs a forward pass starting from a three-channel (R, G, B) image.
     *
     * @param eventId      identifier of this pass, forwarded to the convolution helpers and onward
     * @param outBack      output handler, passed through unchanged to the decoder
     * @param inputMatrix  input image; its {@code getX()}/{@code getY()} must equal the
     *                     {@code xSize}/{@code ySize} given at construction
     * @param expectMatrix expected output; must not be {@code null} when {@code study} is true
     * @param study        {@code true} for a training pass
     * @throws Exception if {@code study} is true and {@code expectMatrix} is null, if either
     *                   dimension of {@code inputMatrix} differs from the configured size, or if a
     *                   downstream stage fails
     */
    public void sendThreeChannel(long eventId, OutBack outBack, ThreeChannelMatrix inputMatrix,
                                 ThreeChannelMatrix expectMatrix, boolean study) throws Exception {
        if (study && expectMatrix == null) {
            // "The expected matrix must not be null during training"
            throw new Exception("训练时期望矩阵不能为空");
        }
        // A mismatch in EITHER dimension is an error; using && here would let an image through
        // when only one dimension is wrong.
        if (inputMatrix.getX() != xSize || inputMatrix.getY() != ySize) {
            // "Input image size does not match the initialization parameters"
            throw new Exception("输入图片尺寸与初始化参数不一致");
        }
        List<Matrix> channels = new ArrayList<>();
        channels.add(inputMatrix.getMatrixR());
        channels.add(inputMatrix.getMatrixG());
        channels.add(inputMatrix.getMatrixB());
        if (study) {
            // Kept so backError can update the input-mixing weights against the raw channels.
            convParameter.setFeatureMatrixList(channels);
        }
        sendMatrixList(eventId, outBack, expectMatrix, channels, study, inputMatrix);
    }

    /**
     * Forward step for a non-first stage: applies this stage's convolution and pooling to
     * {@code featureList}, then passes the result on.
     */
    protected void sendFeature(long eventId, OutBack outBack, ThreeChannelMatrix expectMatrix,
                               List<Matrix> featureList, boolean study,
                               ThreeChannelMatrix inputMatrix) throws Exception {
        List<Matrix> pooled = downConvAndPooling(featureList, convParameter, channelNo,
                activeFunction, kerSize, true, eventId);
        sendToNextStage(eventId, outBack, expectMatrix, pooled, study, inputMatrix);
    }

    /**
     * Backward step: un-pools {@code errorMatrixList} and adds the decoder-supplied error. It then
     * back-propagates through this stage's convolution and hands the result to the previous
     * stage. If there is no previous stage, it updates the input-mixing weights instead.
     *
     * @param errorMatrixList error arriving from the stage after this one
     * @throws Exception if any matrix operation fails
     */
    public void backError(List<Matrix> errorMatrixList) throws Exception {
        List<Matrix> unpooledError = backDownPoolingByList(errorMatrixList,
                convParameter.getOutX(), convParameter.getOutY());
        List<Matrix> totalError = matrixOperation.addMatrixList(unpooledError, decodeErrorMatrix);
        List<Matrix> convError = backAllDownConv(convParameter, totalError, studyRate,
                activeFunction, channelNo, kerSize);
        if (beforeEncoder != null) {
            beforeEncoder.backError(convError);
        } else {
            backOneConvByList(convError, convParameter.getFeatureMatrixList(),
                    convParameter.getOneConvPower(), oneStudyRate, true);
        }
    }

    /**
     * Forward step for the first stage: mixes {@code matrixList} with the input-mixing weights
     * ({@code manyOneConv}), applies this stage's convolution and pooling, and passes the result on.
     */
    public void sendMatrixList(long eventId, OutBack outBack, ThreeChannelMatrix expectMatrix,
                               List<Matrix> matrixList, boolean study,
                               ThreeChannelMatrix inputMatrix) throws Exception {
        List<Matrix> mixed = manyOneConv(matrixList, convParameter.getOneConvPower());
        List<Matrix> pooled = downConvAndPooling(mixed, convParameter, channelNo,
                activeFunction, kerSize, true, eventId);
        sendToNextStage(eventId, outBack, expectMatrix, pooled, study, inputMatrix);
    }

    /** Passes pooled features to the next encoder stage, or to the decoder if this is the last stage. */
    private void sendToNextStage(long eventId, OutBack outBack, ThreeChannelMatrix expectMatrix,
                                 List<Matrix> pooled, boolean study,
                                 ThreeChannelMatrix inputMatrix) throws Exception {
        if (afterEncoder != null) {
            afterEncoder.sendFeature(eventId, outBack, expectMatrix, pooled, study, inputMatrix);
        } else {
            decoder.sendFeature(eventId, outBack, expectMatrix, pooled, study, inputMatrix);
        }
    }

    /**
     * Appends one randomly initialised kernel to {@code nerveMatrixList}. The kernel is a
     * {@code (kerSize * kerSize) x 1} column whose entries are drawn from
     * {@code [0, 1) / kerSize}.
     */
    private void initNervePowerMatrix(Random random, List<Matrix> nerveMatrixList) throws Exception {
        int weightCount = kerSize * kerSize;
        Matrix kernel = new Matrix(weightCount, 1);
        for (int row = 0; row < weightCount; row++) {
            kernel.setNub(row, 0, random.nextFloat() / kerSize);
        }
        nerveMatrixList.add(kernel);
    }

    /** Returns the next encoder stage, or {@code null} if this is the last one. */
    public UNetEncoder getAfterEncoder() {
        return this.afterEncoder;
    }

    /** Sets the next encoder stage that receives this stage's forward output. */
    public void setAfterEncoder(UNetEncoder uNetEncoder) {
        this.afterEncoder = uNetEncoder;
    }

    /** Returns the previous encoder stage, or {@code null} if this is the first one. */
    public UNetEncoder getBeforeEncoder() {
        return this.beforeEncoder;
    }

    /** Sets the previous encoder stage that receives this stage's back-propagated error. */
    public void setBeforeEncoder(UNetEncoder uNetEncoder) {
        this.beforeEncoder = uNetEncoder;
    }

    /** Returns the decoder fed by this stage when it has no next encoder. */
    public UNetDecoder getDecoder() {
        return this.decoder;
    }

    /** Sets the decoder that receives the forward output when this is the last encoder stage. */
    public void setDecoder(UNetDecoder uNetDecoder) {
        this.decoder = uNetDecoder;
    }
}
