package org.dromara.easyai.transFormer;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.dromara.easyai.function.ReLu;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.model.LineBlockModel;
import org.dromara.easyai.transFormer.nerve.HiddenNerve;
import org.dromara.easyai.transFormer.nerve.Nerve;
import org.dromara.easyai.transFormer.nerve.OutNerve;
import org.dromara.easyai.transFormer.nerve.SoftMax;

/* loaded from: input_file:org/dromara/easyai/transFormer/LineBlock.class */
public class LineBlock {
    /** Codec block that receives the aggregated, trimmed error matrix from {@link #backError}. */
    private final CodecBlock lastCodecBlock;
    /** Running sum of the error matrices received in the current backward round; {@code null} between rounds. */
    private Matrix allError;
    /**
     * Number of hidden nerves in this block, which is also the number of {@link #backError} calls
     * that complete one round (presumably one call per hidden nerve — confirm against HiddenNerve).
     */
    private final int featureDimension;
    private final MatrixOperation matrixOperation;
    private final List<HiddenNerve> hiddenNerveList = new ArrayList<>();
    private final List<OutNerve> outNerveList = new ArrayList<>();
    /** Number of {@link #backError} calls received so far in the current round. */
    private int backNumber = 0;

    /**
     * Snapshots the parameters of every hidden and output nerve into a {@link LineBlockModel}.
     *
     * @return a model whose hidden/out lists are in the same order as this block's nerve lists
     * @throws Exception if a nerve fails to export its model
     */
    public LineBlockModel getModel() throws Exception {
        List<float[][]> hiddenModels = new ArrayList<>(hiddenNerveList.size());
        for (HiddenNerve hiddenNerve : hiddenNerveList) {
            hiddenModels.add(hiddenNerve.getModel());
        }
        List<float[][]> outModels = new ArrayList<>(outNerveList.size());
        for (OutNerve outNerve : outNerveList) {
            outModels.add(outNerve.getModel());
        }
        LineBlockModel lineBlockModel = new LineBlockModel();
        lineBlockModel.setHiddenNervesModel(hiddenModels);
        lineBlockModel.setOutNervesModel(outModels);
        return lineBlockModel;
    }

    /**
     * Loads nerve parameters from a model produced by {@link #getModel()}.
     *
     * <p>Both model lists are size-checked before any nerve is touched. A model that is too short
     * is rejected without partially overwriting this block. Extra trailing entries are ignored.
     *
     * @param lineBlockModel the model to load
     * @throws IllegalArgumentException if either model list has fewer entries than this block has nerves
     * @throws Exception if a nerve fails to import its model
     */
    public void insertModel(LineBlockModel lineBlockModel) throws Exception {
        List<float[][]> hiddenNervesModel = lineBlockModel.getHiddenNervesModel();
        List<float[][]> outNervesModel = lineBlockModel.getOutNervesModel();
        // Validate first so a mismatched model cannot leave the hidden layer loaded but the out layer stale.
        if (hiddenNervesModel.size() < hiddenNerveList.size()) {
            throw new IllegalArgumentException("hidden nerve model has " + hiddenNervesModel.size()
                    + " entries but this block has " + hiddenNerveList.size() + " hidden nerves");
        }
        if (outNervesModel.size() < outNerveList.size()) {
            throw new IllegalArgumentException("out nerve model has " + outNervesModel.size()
                    + " entries but this block has " + outNerveList.size() + " out nerves");
        }
        for (int index = 0; index < hiddenNerveList.size(); index++) {
            hiddenNerveList.get(index).insertModel(hiddenNervesModel.get(index));
        }
        for (int index = 0; index < outNerveList.size(); index++) {
            outNerveList.get(index).insertModel(outNervesModel.get(index));
        }
    }

    /**
     * Builds the block: {@code i2} ReLU hidden nerves fully connected to {@code i} output nerves that
     * share one {@link SoftMax}.
     *
     * <p>Parameter meanings beyond what is shown below are not visible from this class. Each value is
     * forwarded as-is to the nerve/SoftMax/MatrixOperation constructors.
     *
     * @param i          number of output nerves; also passed to SoftMax and to each nerve constructor
     * @param i2         number of hidden nerves, stored as {@code featureDimension}
     * @param f          passed to both HiddenNerve and OutNerve constructors
     * @param codecBlock block that receives the aggregated error from {@link #backError}
     * @param z          passed to the SoftMax constructor
     * @param i3         passed to both nerve constructors
     * @param f2         passed to both nerve constructors
     * @param i4         passed to MatrixOperation and to both nerve constructors
     * @param f3         passed to the SoftMax constructor
     * @throws Exception if any nerve or SoftMax constructor fails
     */
    public LineBlock(int i, int i2, float f, CodecBlock codecBlock, boolean z, int i3, float f2, int i4, float f3) throws Exception {
        this.featureDimension = i2;
        this.lastCodecBlock = codecBlock;
        this.matrixOperation = new MatrixOperation(i4);
        // outNerveList is still empty here and is filled below. SoftMax presumably keeps the reference
        // rather than copying it — NOTE(review): confirm in SoftMax.
        SoftMax softMax = new SoftMax(this.outNerveList, z, i, i, i, f3);
        List<Nerve> hiddenNerves = new ArrayList<>();
        for (int index = 0; index < i2; index++) {
            HiddenNerve hiddenNerve = new HiddenNerve(index + 1, 1, f, new ReLu(), i2, i, this, i3, f2, i4);
            hiddenNerves.add(hiddenNerve);
            this.hiddenNerveList.add(hiddenNerve);
        }
        List<Nerve> outNerves = new ArrayList<>();
        for (int index = 0; index < i; index++) {
            OutNerve outNerve = new OutNerve(index + 1, f, i2, i2, i, softMax, i3, f2, i4);
            outNerve.connectFather(hiddenNerves);
            outNerves.add(outNerve);
            this.outNerveList.add(outNerve);
        }
        // Wire the forward direction: every hidden nerve connects to every output nerve.
        for (Nerve hiddenNerve : hiddenNerves) {
            hiddenNerve.connect(outNerves);
        }
    }

    /**
     * Forwards one message to every hidden nerve via {@code postMessage}, passing all arguments
     * through unchanged.
     *
     * @throws Exception if any hidden nerve fails to process the message
     */
    public void sendParameter(long j, Matrix matrix, boolean z, OutBack outBack, List<Integer> list, boolean z2) throws Exception {
        for (HiddenNerve hiddenNerve : hiddenNerveList) {
            hiddenNerve.postMessage(j, matrix, z, outBack, list, z2);
        }
    }

    /**
     * Accumulates one error matrix and, once {@code featureDimension} of them have arrived, sends their
     * sum (with the last column removed) to the upstream codec block.
     *
     * <p>Not synchronized. NOTE(review): if callers can invoke this concurrently, the counter and
     * accumulator need external synchronization — confirm the calling thread model.
     *
     * @param j      identifier forwarded unchanged to {@code lastCodecBlock.backError}; the value from
     *               the call that completes the round is the one used
     * @param matrix error contribution to add to the running sum
     * @throws Exception if the matrix addition or the upstream backError fails
     */
    public void backError(long j, Matrix matrix) throws Exception {
        this.backNumber++;
        this.allError = (this.allError == null) ? matrix : this.matrixOperation.add(matrix, this.allError);
        if (this.backNumber == this.featureDimension) {
            this.backNumber = 0;
            // Sub-matrix starting at (0, 0) with all rows and getY() - 1 columns, i.e. the last column is
            // dropped — assuming getSonOfMatrix(x, y, xSize, ySize); confirm in Matrix.
            Matrix trimmedError = this.allError.getSonOfMatrix(0, 0, this.allError.getX(), this.allError.getY() - 1);
            this.allError = null;
            this.lastCodecBlock.backError(j, trimmedError);
        }
    }
}
