package org.dromara.easyai.transFormer;

import java.util.ArrayList;
import java.util.List;
import org.dromara.easyai.config.TfConfig;
import org.dromara.easyai.transFormer.model.CodecBlockModel;
import org.dromara.easyai.transFormer.model.TransFormerModel;
import org.dromara.easyai.transFormer.model.TransWordVectorModel;
import org.dromara.easyai.transFormer.nerve.SensoryNerve;

/* loaded from: input_file:org/dromara/easyai/transFormer/TransFormerManager.class */
/**
 * Assembles a TransFormer network and converts it to and from a {@link TransFormerModel}.
 *
 * <p>The network consists of a word-vector component, an encoder stack, a decoder stack, a
 * first decoder block and an output line block.
 *
 * <p>No synchronization is performed here. Callers that share an instance across threads must
 * coordinate access themselves.
 */
public class TransFormerManager {
    private final List<CodecBlock> encoderBlocks = new ArrayList<>();
    private final List<CodecBlock> decoderBlocks = new ArrayList<>();
    private SensoryNerve sensoryNerve;
    private FirstDecoderBlock firstDecoderBlock;
    private LineBlock lineBlock;
    // Assigned last during a successful build, so non-null means "fully initialized".
    private TransWordVector transWordVector;

    /**
     * Returns the word-vector component.
     *
     * @return the word vector, or {@code null} if the network has not been initialized
     */
    public TransWordVector getTransWordVector() {
        return this.transWordVector;
    }

    /**
     * Returns the network entry point, which is wired to the first encoder block, the first
     * decoder block and the word vector.
     *
     * @return the sensory nerve, or {@code null} if the network has not been initialized
     */
    public SensoryNerve getSensoryNerve() {
        return this.sensoryNerve;
    }

    /**
     * Exports the current parameters of every component into a new model object.
     *
     * @return a model holding the word-vector, encoder, decoder, first-decoder and line-block
     *     models
     * @throws IllegalStateException if the network has not been initialized
     * @throws Exception if a component fails to export its model
     */
    public TransFormerModel getModel() throws Exception {
        requireInitialized();
        TransFormerModel transFormerModel = new TransFormerModel();
        transFormerModel.setTransWordVectorModel(this.transWordVector.getModel());
        transFormerModel.setEncoderBlockModels(collectModels(this.encoderBlocks));
        transFormerModel.setDecoderBlockModels(collectModels(this.decoderBlocks));
        transFormerModel.setFirstDecoderBlockModel(this.firstDecoderBlock.getModel());
        transFormerModel.setLineBlockModel(this.lineBlock.getModel());
        return transFormerModel;
    }

    /**
     * Rebuilds the network from {@code tfConfig} and loads the parameters stored in
     * {@code transFormerModel}.
     *
     * <p>Encoder and decoder block models are each applied in order, up to the number of blocks
     * and models available. Any extra blocks keep the parameters they were constructed with.
     *
     * @param transFormerModel the previously exported model to load
     * @param tfConfig the configuration used to rebuild the network
     * @throws Exception if the configuration is invalid or a component rejects its model
     */
    public void insertModel(TransFormerModel transFormerModel, TfConfig tfConfig) throws Exception {
        init(tfConfig, null, transFormerModel.getTransWordVectorModel());
        insertBlockModels(this.encoderBlocks, transFormerModel.getEncoderBlockModels());
        // The decoder count is taken from the decoder models, not the encoder models, so a
        // shorter decoder list cannot cause an IndexOutOfBoundsException.
        insertBlockModels(this.decoderBlocks, transFormerModel.getDecoderBlockModels());
        this.firstDecoderBlock.insertModel(transFormerModel.getFirstDecoderBlockModel());
        this.lineBlock.insertModel(transFormerModel.getLineBlockModel());
    }

    /**
     * Initializes the network, or refreshes the word vector if the network already exists.
     *
     * <p>On the first call, the word vector is built from {@code list} and the full network is
     * constructed. On later calls, only {@code TransWordVector.init(list)} is re-run; the blocks
     * are not rebuilt.
     *
     * @param tfConfig network configuration
     * @param list strings passed to {@code TransWordVector.init} (presumably the training
     *     corpus; confirm against TransWordVector)
     * @throws Exception if the configuration is invalid or word-vector initialization fails
     */
    public void init(TfConfig tfConfig, List<String> list) throws Exception {
        if (this.transWordVector == null) {
            init(tfConfig, list, null);
        } else {
            this.transWordVector.init(list);
        }
    }

    /**
     * Builds the word vector and the complete encoder/decoder network.
     *
     * <p>The word vector is initialized from {@code list} when {@code transWordVectorModel} is
     * null. Otherwise it is loaded from the model. The configuration is validated before any
     * existing network is discarded, so a rejected configuration leaves the previous state
     * untouched.
     *
     * @throws Exception if the feature dimension is odd, or if multiNumber &lt;= 1,
     *     featureDimension &lt;= 0, allDepth &lt;= 0 or typeNumber &lt;= 1
     */
    private void init(TfConfig tfConfig, List<String> list, TransWordVectorModel transWordVectorModel) throws Exception {
        TransWordVector wordVector = new TransWordVector(tfConfig);
        if (transWordVectorModel == null) {
            wordVector.init(list);
        } else {
            wordVector.insertModel(transWordVectorModel);
        }
        // With isNorm, the output width follows the vocabulary size instead of the configured type count.
        int typeNumber = tfConfig.isNorm() ? wordVector.getWordSize() : tfConfig.getTypeNumber();
        int multiNumber = tfConfig.getMultiNumber();
        int featureDimension = tfConfig.getFeatureDimension();
        if (featureDimension % 2 != 0) {
            throw new Exception("TransFormer 词向量维度必须为偶数");
        }
        int allDepth = tfConfig.getAllDepth();
        if (multiNumber <= 1 || featureDimension <= 0 || allDepth <= 0 || typeNumber <= 1) {
            throw new Exception("param is null,typeNumber:" + typeNumber + ",featureDimension:" + featureDimension
                    + ",multiNumber:" + multiNumber + ",allDepth:" + allDepth);
        }
        float studyRate = tfConfig.getStudyRate();
        boolean isShowLog = tfConfig.isShowLog();
        int regularModel = tfConfig.getRegularModel();
        float regular = tfConfig.getRegular();
        int coreNumber = tfConfig.getCoreNumber();

        // Validation passed. Drop any previous network so the new blocks are not appended
        // after stale ones.
        resetNetwork();

        // Encoder blocks are created with depth indices 1..allDepth.
        for (int i = 0; i < allDepth; i++) {
            this.encoderBlocks.add(new CodecBlock(multiNumber, featureDimension, studyRate, i + 1, true,
                    regularModel, regular, coreNumber, wordVector));
        }
        CodecBlock lastEncoder = this.encoderBlocks.get(allDepth - 1);
        // Decoder blocks are created with depth indices 2..allDepth+1. Each one is linked to the
        // final encoder block.
        for (int i = 0; i < allDepth; i++) {
            CodecBlock decoder = new CodecBlock(multiNumber, featureDimension, studyRate, i + 2, false,
                    regularModel, regular, coreNumber, wordVector);
            decoder.setLastEncoderBlock(lastEncoder);
            this.decoderBlocks.add(decoder);
        }
        CodecBlock lastDecoder = this.decoderBlocks.get(allDepth - 1);
        connectCodecBlock(this.encoderBlocks);
        connectCodecBlock(this.decoderBlocks);

        LineBlock newLineBlock = new LineBlock(typeNumber, featureDimension, studyRate, lastDecoder, isShowLog,
                regularModel, regular, coreNumber, tfConfig.getTimePunValue());
        lastDecoder.setLineBlock(newLineBlock);
        this.lineBlock = newLineBlock;

        CodecBlock firstDecoder = this.decoderBlocks.get(0);
        FirstDecoderBlock newFirstDecoderBlock = new FirstDecoderBlock(multiNumber, featureDimension, studyRate,
                firstDecoder, coreNumber, wordVector);
        newFirstDecoderBlock.setLastEncoderBlock(lastEncoder);
        firstDecoder.setFirstDecoderBlock(newFirstDecoderBlock);
        this.firstDecoderBlock = newFirstDecoderBlock;

        this.sensoryNerve = new SensoryNerve(this.encoderBlocks.get(0), newFirstDecoderBlock, wordVector);
        // Published last: a failure anywhere above leaves the manager uninitialized, so a later
        // public init() performs a full rebuild instead of only refreshing the word vector.
        this.transWordVector = wordVector;
    }

    /**
     * Links adjacent blocks in list order.
     *
     * <p>For each pair (i, i+1), block i's "before" link is set to block i+1, and block i+1's
     * "after" link is set to block i.
     */
    private void connectCodecBlock(List<CodecBlock> list) {
        int size = list.size();
        for (int i = 0; i < size - 1; i++) {
            CodecBlock current = list.get(i);
            CodecBlock next = list.get(i + 1);
            current.setBeforeEncoderBlock(next);
            next.setAfterEncoderBlock(current);
        }
    }

    /** Discards every network component, leaving the manager uninitialized. */
    private void resetNetwork() {
        this.encoderBlocks.clear();
        this.decoderBlocks.clear();
        this.lineBlock = null;
        this.firstDecoderBlock = null;
        this.sensoryNerve = null;
        this.transWordVector = null;
    }

    /** Throws if no network has been successfully built yet. */
    private void requireInitialized() {
        if (this.transWordVector == null) {
            throw new IllegalStateException("TransFormerManager has not been initialized");
        }
    }

    /** Exports the model of each block in {@code blocks}, preserving order. */
    private static List<CodecBlockModel> collectModels(List<CodecBlock> blocks) throws Exception {
        List<CodecBlockModel> models = new ArrayList<>(blocks.size());
        for (CodecBlock block : blocks) {
            models.add(block.getModel());
        }
        return models;
    }

    /** Loads models into blocks pairwise, up to the shorter of the two lists. */
    private static void insertBlockModels(List<CodecBlock> blocks, List<CodecBlockModel> models) throws Exception {
        int count = Math.min(blocks.size(), models.size());
        for (int i = 0; i < count; i++) {
            blocks.get(i).insertModel(models.get(i));
        }
    }
}
