package org.dromara.easyai.rnnJumpNerveCenter;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.easyai.config.SentenceConfig;
import org.dromara.easyai.entity.CreatorModel;
import org.dromara.easyai.entity.SemanticsBack;
import org.dromara.easyai.entity.TalkBody;
import org.dromara.easyai.function.Tanh;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.naturalLanguage.word.WordEmbedding;
import org.dromara.easyai.rnnJumpNerveEntity.MyWordFeature;
import org.dromara.easyai.rnnJumpNerveEntity.NerveCenter;
import org.dromara.easyai.rnnJumpNerveEntity.SensoryNerve;

/**
 * Legacy question/answer generator built on a {@link NerveJumpManager} whose inputs are word
 * embeddings produced by a {@link WordEmbedding}.
 *
 * <p>{@link #init()} must be called first: it creates the underlying network that
 * {@link #study(List)}, {@link #insertModel(CreatorModel)} and {@link #getAnswer(String, long)}
 * all use.
 *
 * @deprecated kept for backward compatibility. NOTE(review): document the replacement API here.
 */
@Deprecated
public class CustomManager {
    /** Source of word embeddings and word IDs; also attached to every nerve center in init(). */
    private final WordEmbedding wordEmbedding;
    /** Underlying network; null until {@link #init()} runs. */
    private NerveJumpManager semanticsManager;
    /** Word-vector dimension, passed twice to the NerveJumpManager constructor. */
    private final int vectorDimension;
    /** Maximum question length in characters; also the storey index where answers start. */
    private final int maxFeatureLength;
    /** Passed through to NerveJumpManager (presumably the learning rate; confirm). */
    private final float studyPoint;
    /** Questions/answers shorter than this use every storey instead of a random subset. */
    private final int minLength;
    /** Maximum answer length in characters. */
    private final int answerMaxLength;
    /** Multiplier on the number of training passes; always at least 1. */
    private final int times;
    /** Passed through to NerveJumpManager; meaning defined there. */
    private final float param;
    /** Forwarded to {@code NerveJumpManager.setPowerTh}. */
    private final float powerTh;
    /** Passed through to NerveJumpManager; meaning defined there. */
    private final int rzModel;

    /**
     * Captures the configuration; does not build the network (see {@link #init()}).
     *
     * @param wordEmbedding  embedding provider used for training and inference
     * @param sentenceConfig source of all sizing/training parameters
     */
    public CustomManager(WordEmbedding wordEmbedding, SentenceConfig sentenceConfig) {
        this.minLength = sentenceConfig.getMinLength();
        this.wordEmbedding = wordEmbedding;
        this.vectorDimension = sentenceConfig.getQaWordVectorDimension();
        this.maxFeatureLength = sentenceConfig.getMaxWordLength();
        this.studyPoint = sentenceConfig.getWeStudyPoint();
        this.answerMaxLength = sentenceConfig.getMaxAnswerLength();
        this.powerTh = sentenceConfig.getSentenceTrustPowerTh();
        // At least one training pass is always performed.
        this.times = Math.max(1, sentenceConfig.getTimes());
        this.param = sentenceConfig.getParam();
        this.rzModel = sentenceConfig.getRzModel();
    }

    /**
     * Builds the jump-RNN and attaches the word embedding to every nerve center.
     *
     * @throws Exception propagated from NerveJumpManager construction/initialisation
     */
    public void init() throws Exception {
        // Third argument is the vocabulary size; depth spans a full question plus answer, minus one.
        // NOTE(review): exact meaning of each constructor argument is defined in NerveJumpManager.
        this.semanticsManager = new NerveJumpManager(this.vectorDimension, this.vectorDimension,
                this.wordEmbedding.getWordList().size(), (this.maxFeatureLength + this.answerMaxLength) - 1,
                new Tanh(), false, this.studyPoint, this.rzModel, this.param);
        this.semanticsManager.setPowerTh(this.powerTh);
        this.semanticsManager.initRnn(true, true, true, true, this.maxFeatureLength);
        Iterator<NerveCenter> it = this.semanticsManager.getNerveCenterList().iterator();
        while (it.hasNext()) {
            it.next().setWordEmbedding(this.wordEmbedding);
        }
    }

    /**
     * Loads previously trained parameters into the network created by {@link #init()}.
     *
     * @param creatorModel model returned earlier by {@link #study(List)}
     * @throws Exception propagated from the network
     */
    public void insertModel(CreatorModel creatorModel) throws Exception {
        this.semanticsManager.insertModelParameter(creatorModel.getSemanticsModel());
    }

    /**
     * Runs the network in inference mode on a question and returns the generated word.
     *
     * <p>The question is truncated to {@code maxFeatureLength} characters and embedded with id
     * {@code j}. A zero row is appended after the last feature row. The storey path is
     * {@code 0..len-1} followed by {@code maxFeatureLength}, the first answer storey.
     *
     * @param str question text
     * @param j   id passed to the embedding and to every sensory nerve
     * @return the word collected in the {@link SemanticsBack} after propagation
     * @throws Exception propagated from embedding or the network
     */
    public String getAnswer(String str, long j) throws Exception {
        SemanticsBack semanticsBack = new SemanticsBack();
        String question = truncate(str, this.maxFeatureLength);
        MyWordFeature embedding = this.wordEmbedding.getEmbedding(question, j, false);
        List<Float> firstFeatureList = embedding.getFirstFeatureList();
        Matrix featureMatrix = embedding.getFeatureMatrix();
        studySemanticsNerve(j, firstFeatureList, false, null, semanticsBack,
                insertZero(featureMatrix, featureMatrix.getX()),
                getStoreys2(question.length(), this.maxFeatureLength), question.length());
        return semanticsBack.getWord();
    }

    /**
     * Trains the network on every question/answer pair.
     *
     * <p>The corpus is iterated {@code maxFeatureLength * answerMaxLength * times} times.
     * Questions and answers are truncated to their configured maximum lengths before training.
     *
     * @param list training pairs
     * @return a model holding the trained network parameters
     * @throws Exception propagated from embedding or the network
     */
    public CreatorModel study(List<TalkBody> list) throws Exception {
        Random random = new Random();
        CreatorModel creatorModel = new CreatorModel();
        int epochs = this.maxFeatureLength * this.answerMaxLength * this.times;
        // long: list.size() * epochs can exceed Integer.MAX_VALUE for large corpora.
        long totalSteps = (long) list.size() * epochs;
        long step = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            System.out.println("生成模型学习完成次数：" + (epoch + 1) + "总共次数:" + epochs);
            for (TalkBody talkBody : list) {
                step++;
                String question = truncate(talkBody.getQuestion(), this.maxFeatureLength);
                String answer = truncate(talkBody.getAnswer(), this.answerMaxLength);
                // Training always uses the fixed id 1L for both the embedding and the nerves.
                semanticsStudy(this.wordEmbedding.getEmbedding(question + answer, 1L, false),
                        question, answer, 1L, random);
                // Float division: integer division would print 0% until the final step.
                float progress = (float) step / totalSteps * 100.0f;
                System.out.println("训练进度：" + String.format("%.6f", Float.valueOf(progress)) + "%");
            }
        }
        creatorModel.setSemanticsModel(this.semanticsManager.getModelParameter());
        return creatorModel;
    }

    /** Returns {@code text} cut to at most {@code maxLength} characters. */
    private static String truncate(String text, int maxLength) {
        return text.length() > maxLength ? text.substring(0, maxLength) : text;
    }

    /**
     * Builds the inference storey path {@code [0, 1, ..., length-1, lastStorey]}.
     *
     * @param length     number of question storeys
     * @param lastStorey storey appended after the question storeys
     */
    private int[] getStoreys2(int length, int lastStorey) {
        int[] storeys = new int[length + 1];
        for (int k = 0; k < length; k++) {
            storeys[k] = k;
        }
        storeys[length] = lastStorey;
        return storeys;
    }

    /**
     * Picks an ascending set of storeys in {@code [offset, offset + length)}.
     *
     * <p>If {@code length < minLength}, every storey is returned. Otherwise the result holds
     * {@code offset} plus a random subset of the remaining storeys. The total count is drawn
     * uniformly from {@code [minLength, length]}.
     *
     * @param length number of candidate positions
     * @param random source of randomness (used for both the count and the subset)
     * @param offset value added to every position
     */
    private int[] getStoreys(int length, Random random, int offset) {
        if (length < this.minLength) {
            int[] storeys = new int[length];
            for (int k = 0; k < length; k++) {
                storeys[k] = k + offset;
            }
            return storeys;
        }
        List<Integer> candidates = new ArrayList<>();
        for (int k = 1; k < length; k++) {
            candidates.add(k);
        }
        // Uniform in [minLength, length]. The previous (float) Math.random() form could round up
        // and yield length + 1, exhausting the candidate list.
        int count = this.minLength + random.nextInt(length - this.minLength + 1);
        int[] storeys = new int[count];
        if (count > 0) {
            storeys[0] = offset;
        }
        for (int k = 1; k < count; k++) {
            int pick = random.nextInt(candidates.size());
            // remove(int) removes by index and returns the removed element.
            storeys[k] = candidates.remove(pick) + offset;
        }
        Arrays.sort(storeys);
        return storeys;
    }

    /** Returns a copy of {@code array} with {@code value} appended. */
    private int[] pushArray(int[] array, int value) {
        int[] grown = Arrays.copyOf(array, array.length + 1);
        grown[array.length] = value;
        return grown;
    }

    /**
     * Returns a copy of {@code matrix} with one extra row inserted at {@code index}.
     *
     * <p>The inserted row is never written, so it keeps whatever the {@link Matrix} constructor
     * initialises it to (presumably zero; confirm).
     *
     * @param matrix source matrix
     * @param index  row index at which the blank row is inserted
     * @throws Exception propagated from Matrix accessors
     */
    private Matrix insertZero(Matrix matrix, int index) throws Exception {
        int cols = matrix.getY();
        Matrix result = new Matrix(matrix.getX() + 1, cols);
        int rows = result.getX();
        for (int r = 0; r < rows; r++) {
            if (r == index) {
                continue;
            }
            int source = r < index ? r : r - 1;
            for (int c = 0; c < cols; c++) {
                result.setNub(r, c, matrix.getNumber(source, c));
            }
        }
        return result;
    }

    /**
     * Trains the network on a single question/answer pair.
     *
     * <p>A random storey path is built for the question. An ascending set of answer storeys,
     * offset by {@code maxFeatureLength}, is built as well. For each consecutive pair of answer
     * storeys, the current storey is appended to the path. The network is then trained toward
     * the answer character selected by the next storey (target key {@code getID(char) + 1},
     * weight 1). Pairs whose question has at most one character, or whose answer is empty, are
     * skipped.
     */
    private void semanticsStudy(MyWordFeature myWordFeature, String str, String str2, long j, Random random) throws Exception {
        System.out.println("训练question:" + str + ",answer:" + str2);
        if (str.length() <= 1 || str2.isEmpty()) {
            return;
        }
        // Blank row between the question rows and the answer rows of the feature matrix.
        Matrix features = insertZero(myWordFeature.getFeatureMatrix(), str.length());
        List<Float> firstFeatureList = myWordFeature.getFirstFeatureList();
        int[] path = getStoreys(str.length(), random, 0);
        int[] answerStoreys = getStoreys(str2.length() + 1, random, this.maxFeatureLength);
        // Reused (cleared) each step, as before.
        Map<Integer, Float> target = new HashMap<>();
        for (int k = 0; k < answerStoreys.length - 1; k++) {
            target.clear();
            int answerPos = answerStoreys[k + 1] - this.maxFeatureLength;
            path = pushArray(path, answerStoreys[k]);
            target.put(this.wordEmbedding.getID(str2.substring(answerPos - 1, answerPos)) + 1, 1.0f);
            studySemanticsNerve(j, firstFeatureList, true, target, null, features, path, str.length());
        }
    }

    /**
     * Sends one feature value to each sensory nerve, together with the shared training or
     * inference context.
     *
     * @throws Exception if the feature count differs from the sensory-nerve count, or from the network
     */
    private void studySemanticsNerve(long j, List<Float> list, boolean z, Map<Integer, Float> map, SemanticsBack semanticsBack, Matrix matrix, int[] iArr, int i) throws Exception {
        List<SensoryNerve> sensoryNerves = this.semanticsManager.getSensoryNerves();
        if (sensoryNerves.size() != list.size()) {
            throw new Exception("1size not equals,feature size:" + list.size() + ",sensorySize:" + sensoryNerves.size());
        }
        for (int k = 0; k < sensoryNerves.size(); k++) {
            sensoryNerves.get(k).postMessage(j, list.get(k).floatValue(), z, map, semanticsBack, matrix, iArr, i);
        }
    }
}
