package org.dromara.easyai.naturalLanguage.word;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.dromara.easyai.config.SentenceConfig;
import org.dromara.easyai.entity.SentenceModel;
import org.dromara.easyai.entity.WordMatrix;
import org.dromara.easyai.entity.WordTwoVectorModel;
import org.dromara.easyai.function.Tanh;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.MatrixList;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.rnnJumpNerveEntity.MyWordFeature;
import org.dromara.easyai.rnnNerveCenter.NerveManager;
import org.dromara.easyai.rnnNerveEntity.SensoryNerve;

/* loaded from: input_file:org/dromara/easyai/naturalLanguage/word/WordEmbedding.class */
/**
 * Trains and serves word (or character) embeddings using a {@link NerveManager} network
 * whose input and output layers are both sized to the vocabulary.
 *
 * <p>Each vocabulary entry is identified by its index in {@link #getWordList()}. It is fed to
 * the network as a one-hot input over the sensory nerves. During training ({@link #start()}),
 * every ordered pair of distinct words in a sentence is presented as (input word, target
 * word). Embeddings are read back in {@link #getEmbedding(String, long, boolean)} through a
 * {@link WordMatrix} of length {@link #getWordVectorDimension()}.
 *
 * <p>Usage order: {@link #setConfig(SentenceConfig)} first, then either
 * {@link #init(SentenceModel, int)} (for training) or
 * {@link #insertModel(WordTwoVectorModel, int)} (to load a trained model).
 *
 * <p>NOTE(review): no synchronization is used; treat instances as single-threaded.
 */
public class WordEmbedding extends MatrixOperation {
    /** Sentinel target id meaning "do not put any training target into the target map". */
    private static final int NO_TARGET = -1;

    private NerveManager nerveManager;
    private SentenceModel sentenceModel;
    private SentenceConfig config;
    private int wordVectorDimension;
    /** Vocabulary; a word's index in this list is its one-hot id. */
    private final List<String> wordList = new ArrayList<>();
    private int studyTimes = 1;

    /**
     * Sets how many passes {@link #start()} makes over all training sentences.
     *
     * @param studyTimes number of passes; values below 1 make {@link #start()} train nothing
     */
    public void setStudyTimes(int studyTimes) {
        this.studyTimes = studyTimes;
    }

    /**
     * Sets the configuration whose learning parameters are passed to the network.
     * Must be called before {@link #init} or {@link #insertModel}.
     *
     * @param sentenceConfig the configuration to use
     */
    public void setConfig(SentenceConfig sentenceConfig) {
        this.config = sentenceConfig;
    }

    /** @return the embedding length configured by the last {@code init}/{@code insertModel} */
    public int getWordVectorDimension() {
        return this.wordVectorDimension;
    }

    /**
     * Prepares a fresh, untrained network for the vocabulary of {@code sentenceModel}.
     *
     * @param sentenceModel       supplies the vocabulary ({@code getWordSet()}) and, later,
     *                            the training sentences used by {@link #start()}
     * @param wordVectorDimension embedding length; passed to {@code NerveManager} as its second
     *                            constructor argument (presumably the hidden size — confirm)
     * @throws IllegalStateException if {@link #setConfig} has not been called
     * @throws Exception             propagated from {@code NerveManager}
     */
    public void init(SentenceModel sentenceModel, int wordVectorDimension) throws Exception {
        requireConfig();
        this.wordVectorDimension = wordVectorDimension;
        this.sentenceModel = sentenceModel;
        // Reset first so calling init again does not duplicate the vocabulary.
        this.wordList.clear();
        this.wordList.addAll(sentenceModel.getWordSet());
        int vocabularySize = this.wordList.size();
        this.nerveManager = new NerveManager(vocabularySize, wordVectorDimension, vocabularySize, 1,
                new Tanh(), this.config.getWeStudyPoint(), this.config.getRzModel(), this.config.getWeLParam());
        this.nerveManager.init(true, false, true);
    }

    /**
     * Returns the live vocabulary list (not a copy). Modifying it changes the ids that
     * {@link #getID(String)} returns.
     *
     * @return the internal vocabulary list
     */
    public List<String> getWordList() {
        return this.wordList;
    }

    /**
     * @param id vocabulary index
     * @return the word stored at {@code id}
     * @throws IndexOutOfBoundsException if {@code id} is outside the vocabulary
     */
    public String getWord(int id) {
        return this.wordList.get(id);
    }

    /**
     * Rebuilds the network from a previously trained model and loads its parameters.
     *
     * @param wordTwoVectorModel  model supplying the vocabulary and network parameters
     * @param wordVectorDimension embedding length the model was trained with
     * @throws IllegalStateException if {@link #setConfig} has not been called
     * @throws Exception             propagated from {@code NerveManager}
     */
    public void insertModel(WordTwoVectorModel wordTwoVectorModel, int wordVectorDimension) throws Exception {
        requireConfig();
        // Snapshot before clearing: the model's list may be this.wordList itself, and clearing
        // first would wipe the source.
        List<String> modelWords = new ArrayList<>(wordTwoVectorModel.getWordList());
        this.wordList.clear();
        this.wordList.addAll(modelWords);
        this.wordVectorDimension = wordVectorDimension;
        int vocabularySize = this.wordList.size();
        // NOTE(review): literal 0 / 0.0f replace config.getRzModel()/getWeLParam() used in init;
        // presumably regularization is irrelevant for a loaded model — confirm.
        this.nerveManager = new NerveManager(vocabularySize, wordVectorDimension, vocabularySize, 1,
                new Tanh(), this.config.getWeStudyPoint(), 0, 0.0f);
        this.nerveManager.init(true, false, true);
        this.nerveManager.insertModelParameter(wordTwoVectorModel.getModelParameter());
    }

    /**
     * Runs the network forward (no learning) and collects embedding vectors.
     *
     * @param str       text to embed; must not be empty
     * @param eventId   event id forwarded to the sensory nerves
     * @param wholeWord {@code true} to embed {@code str} as a single vocabulary token;
     *                  {@code false} to embed each {@code char} of {@code str} separately
     * @return a feature whose first-feature list comes from the first token and whose
     *         feature matrix stacks the vectors of all embedded tokens in order
     * @throws IllegalArgumentException if {@code str} is empty
     * @throws Exception                propagated from the network
     */
    public MyWordFeature getEmbedding(String str, long eventId, boolean wholeWord) throws Exception {
        if (str.isEmpty()) {
            throw new IllegalArgumentException("cannot compute an embedding for an empty string");
        }
        MyWordFeature feature = new MyWordFeature();
        int tokenCount = wholeWord ? 1 : str.length();
        MatrixList matrixList = null;
        for (int index = 0; index < tokenCount; index++) {
            // Character mode splits by Java char, so surrogate pairs are looked up as halves.
            String token = wholeWord ? str : str.substring(index, index + 1);
            WordMatrix wordMatrix = new WordMatrix(this.wordVectorDimension);
            studyDNN(eventId, getID(token), NO_TARGET, wordMatrix, false);
            if (matrixList == null) {
                feature.setFirstFeatureList(wordMatrix.getList());
                matrixList = new MatrixList(wordMatrix.getVector(), true);
            } else {
                matrixList.add(wordMatrix.getVector());
            }
        }
        feature.setFeatureMatrix(matrixList.getMatrix());
        return feature;
    }

    /**
     * Posts a one-hot input for {@code inputId} to every sensory nerve.
     *
     * @param eventId  event id forwarded to the nerves
     * @param inputId  vocabulary id whose nerve receives 1.0 (all others receive 0.0)
     * @param targetId vocabulary id of the training target, or {@link #NO_TARGET} for none;
     *                 it is stored under key {@code targetId + 1} (output ids appear to be
     *                 1-based — TODO confirm against NerveManager)
     * @param outBack  receiver for the network output; may be {@code null}
     * @param isStudy  whether the nerves should learn from this message
     */
    private void studyDNN(long eventId, int inputId, int targetId, OutBack outBack, boolean isStudy) throws Exception {
        List<SensoryNerve> sensoryNerves = this.nerveManager.getSensoryNerves();
        int size = sensoryNerves.size();
        HashMap<Integer, Float> targets = new HashMap<>();
        // ">= 0" rather than "> 0": id 0 is a real word and must be trainable as a target.
        if (targetId >= 0) {
            targets.put(targetId + 1, 1.0f);
        }
        for (int index = 0; index < size; index++) {
            float value = index == inputId ? 1.0f : 0.0f;
            sensoryNerves.get(index).postMessage(eventId, value, isStudy, targets, outBack, true, null);
        }
    }

    /**
     * Trains the network for {@code studyTimes} passes over the sentences of the model given
     * to {@link #init}, printing per-sentence timing and overall progress.
     *
     * @return a model holding the trained parameters and a copy of the vocabulary
     * @throws IllegalStateException if {@link #init} has not been called
     * @throws Exception             propagated from the network
     */
    public WordTwoVectorModel start() throws Exception {
        if (this.sentenceModel == null) {
            throw new IllegalStateException("init() must be called before start()");
        }
        List<String[]> sentenceList = this.sentenceModel.getSentenceList();
        int size = sentenceList.size();
        System.out.println("词嵌入训练启动...");
        int totalSteps = this.studyTimes * size;
        int completedSteps = 0;
        for (int pass = 0; pass < this.studyTimes; pass++) {
            for (int index = 0; index < size; index++) {
                completedSteps++;
                long begin = System.currentTimeMillis();
                study(sentenceList.get(index));
                // Float arithmetic: integer division would print 0% until the final step.
                float progress = completedSteps * 100.0f / totalSteps;
                System.out.println("size:" + size + ",index:" + index + ",耗时:" + (System.currentTimeMillis() - begin) + ",完成度:" + String.format("%.6f", progress) + "%");
            }
        }
        WordTwoVectorModel wordTwoVectorModel = new WordTwoVectorModel();
        wordTwoVectorModel.setModelParameter(this.nerveManager.getModelParameter());
        // Copy so the returned model does not alias (and cannot be cleared through) wordList.
        wordTwoVectorModel.setWordList(new ArrayList<>(this.wordList));
        return wordTwoVectorModel;
    }

    /**
     * Trains on one sentence: every ordered pair of distinct positions (i, j) is presented
     * as input word i with target word j. All messages use event id 1.
     *
     * @param sentence the sentence's words
     */
    private void study(String[] sentence) throws Exception {
        int[] ids = new int[sentence.length];
        for (int position = 0; position < sentence.length; position++) {
            ids[position] = getID(sentence[position]);
        }
        for (int center = 0; center < ids.length; center++) {
            for (int context = 0; context < ids.length; context++) {
                if (center != context) {
                    studyDNN(1L, ids[center], ids[context], null, true);
                }
            }
        }
    }

    /**
     * Looks up a word's vocabulary id.
     *
     * @param str the word to look up
     * @return the index of the first matching entry, or 0 if the word is not in the
     *         vocabulary (unknown words therefore share id 0 with the first entry)
     */
    public int getID(String str) {
        int index = this.wordList.indexOf(str);
        return index < 0 ? 0 : index;
    }
}
