package org.dromara.easyai.naturalLanguage;

import java.util.ArrayList;
import java.util.List;
import org.dromara.easyai.config.TfConfig;
import org.dromara.easyai.entity.TalkBody;
import org.dromara.easyai.naturalLanguage.word.WordBack;
import org.dromara.easyai.transFormer.TransFormerManager;
import org.dromara.easyai.transFormer.TransWordVector;
import org.dromara.easyai.transFormer.model.TransFormerModel;
import org.dromara.easyai.transFormer.nerve.SensoryNerve;

/* loaded from: input_file:org/dromara/easyai/naturalLanguage/TalkToTalk.class */
/**
 * Question-to-answer dialogue built on a {@link TransFormerManager}.
 *
 * <p>Typical use: either train a model with {@link #study(List)} or load one with
 * {@link #insertModel(TransFormerModel)}, then generate replies with {@link #getAnswer(String, long)}.
 */
public class TalkToTalk {
    /** Configuration handed to the transformer manager on init / insertModel. */
    private final TfConfig tfConfig;
    /** Maximum number of words appended to a single answer in getAnswer. */
    private final int maxLength;
    /** Number of full passes over the training list performed by study; validated to be > 0. */
    private final int times;
    /** Holds the transformer; initialized via init (from study) or insertModel. */
    private final TransFormerManager transFormerManager = new TransFormerManager();
    /** True when a non-null, non-empty splitWord is configured. */
    private final boolean splitAnswer;
    /** Separator written before every generated word when splitAnswer is true; may be null. */
    private final String splitWord;

    /**
     * Creates a talker from the given configuration.
     *
     * @param tfConfig configuration; its split word, max answer length and training pass count are
     *     captured here, and the object itself is kept for later transformer initialization
     * @throws Exception if {@code tfConfig.getTimes()} is not greater than 0
     */
    public TalkToTalk(TfConfig tfConfig) throws Exception {
        this.splitWord = tfConfig.getSplitWord();
        this.splitAnswer = this.splitWord != null && !this.splitWord.isEmpty();
        this.tfConfig = tfConfig;
        this.maxLength = tfConfig.getMaxLength();
        this.times = tfConfig.getTimes();
        if (this.times <= 0) {
            // Message: "parameter times must be greater than 0".
            throw new Exception("参数times必须大于0");
        }
    }

    /**
     * Initializes the transformer manager, using every question and every answer as sentences.
     *
     * @param list question/answer pairs; each contributes its question then its answer, in order
     * @throws Exception propagated from {@link TransFormerManager#init}
     */
    private void init(List<TalkBody> list) throws Exception {
        // Each pair adds exactly two sentences, so presize to avoid regrowth.
        List<String> sentences = new ArrayList<>(list.size() * 2);
        for (TalkBody talkBody : list) {
            sentences.add(talkBody.getQuestion());
            sentences.add(talkBody.getAnswer());
        }
        this.transFormerManager.init(this.tfConfig, sentences);
    }

    /**
     * Generates an answer for {@code str} one word at a time.
     *
     * <p>Each step feeds the question plus the partial answer built so far (null on the first step)
     * to the sensory nerve in non-learning mode. It then reads the predicted word id. Generation stops
     * when the end id is produced or when {@code maxLength} words have been appended. At least one step
     * always runs, even if {@code maxLength <= 0}. All occurrences of {@code tfConfig.startWord} are
     * removed from the result.
     *
     * <p>NOTE(review): when a split word is configured it is written before every word, so the result
     * begins with the separator. This is kept as-is because the same text is fed back as context.
     *
     * @param str the question text
     * @param j identifier forwarded as the first argument of {@code postSentence} (study uses 1L);
     *     presumably an event/session id — TODO confirm
     * @return the generated answer, possibly empty
     * @throws Exception propagated from the transformer
     */
    public String getAnswer(String str, long j) throws Exception {
        SensoryNerve sensoryNerve = this.transFormerManager.getSensoryNerve();
        TransWordVector transWordVector = this.transFormerManager.getTransWordVector();
        int endID = transWordVector.getEndID();
        WordBack wordBack = new WordBack();
        StringBuilder answer = new StringBuilder();
        int generated = 0;
        do {
            // No partial answer yet on the first step: pass null rather than "".
            String partial = answer.length() > 0 ? answer.toString() : null;
            sensoryNerve.postSentence(j, str, partial, false, wordBack);
            int id = wordBack.getId();
            if (id == endID) {
                break;
            }
            String word = transWordVector.getWordByID(id);
            if (this.splitAnswer) {
                answer.append(this.splitWord);
            }
            answer.append(word);
            generated++;
        } while (generated < this.maxLength);
        return answer.toString().replace(this.tfConfig.startWord, "");
    }

    /**
     * Loads an existing model into the transformer manager using this instance's configuration.
     *
     * @param transFormerModel the model to load
     * @throws Exception propagated from {@link TransFormerManager#insertModel}
     */
    public void insertModel(TransFormerModel transFormerModel) throws Exception {
        this.transFormerManager.insertModel(transFormerModel, this.tfConfig);
    }

    /**
     * Trains on the given question/answer pairs and returns the resulting model.
     *
     * <p>First initializes the manager from all questions and answers. It then makes {@code times}
     * passes over the list, posting each pair in learning mode and printing a progress line to stdout.
     *
     * @param list training question/answer pairs
     * @return the trained model from {@link TransFormerManager#getModel()}
     * @throws Exception propagated from initialization or training
     */
    public TransFormerModel study(List<TalkBody> list) throws Exception {
        init(list);
        SensoryNerve sensoryNerve = this.transFormerManager.getSensoryNerve();
        int size = list.size();
        for (int round = 0; round < this.times; round++) {
            int index = 0; // 1-based position within the list (round itself is 0-based)
            for (TalkBody talkBody : list) {
                index++;
                String question = talkBody.getQuestion();
                String answer = talkBody.getAnswer();
                // Progress: question, answer, sentence index, total count, current round, total rounds.
                // NOTE(review): consider a logger instead of System.out.
                System.out.println("问题:" + question + ", 回答:" + answer + ",训练语句下标:" + index + ",总数量:" + size + ",当前次数：" + round + ",总次数:" + this.times);
                sensoryNerve.postSentence(1L, question, answer, true, null);
            }
        }
        return this.transFormerManager.getModel();
    }
}
