/*
 * Decompiled with CFR 0.152.
 */
package com.els.modules.ai.core.nlp.algorithm;

import com.els.modules.ai.core.nlp.AiOrderCreationModelProperties;
import com.els.modules.ai.core.nlp.algorithm.ModelInfoMatchingAlgorithm;
import com.els.modules.ai.core.nlp.model.FieldSpec;
import com.els.modules.ai.core.nlp.model.ModelInfo;
import java.io.File;
import java.io.InputStream;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class ModelInfoAdvancedMatchingAlgorithm
implements ModelInfoMatchingAlgorithm {
    private static final Logger log = LoggerFactory.getLogger(ModelInfoAdvancedMatchingAlgorithm.class);

    /** Splits a model description into words. Compiled once rather than on every call to String.split. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /** Minimum Word2Vec similarity for a (field/alias, question word) pair to count in the semantic score. */
    private static final double SEMANTIC_SIMILARITY_THRESHOLD = 0.4;

    /** Minimum Word2Vec similarity for a question word to count as covering a field. */
    private static final double FIELD_SIMILARITY_THRESHOLD = 0.7;

    /** Score bonus when the model name and a question word contain one another; the result is capped at 1.0. */
    private static final double MODEL_NAME_BONUS = 0.5;

    /**
     * Loaded asynchronously by {@link #loadWord2VecModel()}. The write happens before
     * {@code modelLoaded.set(true)}, and readers check {@code modelLoaded.get()} first,
     * so the AtomicBoolean publishes this field safely.
     */
    private Word2Vec word2Vec;

    /**
     * True once this instance's {@link #word2Vec} is usable. This is per instance (not static)
     * so the flag can never claim a model that this instance has not loaded.
     */
    private final AtomicBoolean modelLoaded = new AtomicBoolean(false);

    @Autowired
    private AiOrderCreationModelProperties aiOrderCreationModelProperties;

    /**
     * Starts a daemon thread named "Word2Vec-Model-Loader" that loads the Word2Vec model.
     * Until loading succeeds, {@link #calculateModelMatchingScore} uses keyword matching only.
     * NOTE(review): no lifecycle annotation is present here; presumably a caller invokes this — confirm.
     */
    public void init() {
        Thread modelLoadingThread = new Thread(this::loadWord2VecModel);
        modelLoadingThread.setDaemon(true);
        modelLoadingThread.setName("Word2Vec-Model-Loader");
        modelLoadingThread.start();
    }

    /**
     * Copies the classpath resource named by {@code getWord2VecPath()} to a temp file, reads it
     * with {@link WordVectorSerializer#readWord2VecModel(File)}, and marks the model as loaded.
     * This is best-effort: any failure is logged and leaves {@code modelLoaded} false, so scoring
     * keeps falling back to keyword matching.
     */
    private void loadWord2VecModel() {
        String modelPath = null;
        File tempFile = null;
        try {
            modelPath = this.aiOrderCreationModelProperties.getWord2VecPath();
            if (modelPath == null || modelPath.isEmpty()) {
                log.warn("Word2Vec model path is not configured; falling back to keyword matching");
                return;
            }
            try (InputStream inputStream = this.getClass().getClassLoader().getResourceAsStream(modelPath)) {
                // An explicit check instead of `assert`: assertions are normally disabled at runtime.
                if (inputStream == null) {
                    log.warn("Word2Vec model resource '{}' not found on classpath; falling back to keyword matching", modelPath);
                    return;
                }
                tempFile = File.createTempFile("word2vec", ".bin");
                tempFile.deleteOnExit();
                FileUtils.copyToFile(inputStream, tempFile);
                this.word2Vec = WordVectorSerializer.readWord2VecModel(tempFile);
            }
            modelLoaded.set(true);
            log.info("Word2Vec model loaded from classpath resource '{}'", modelPath);
        }
        catch (Exception e) {
            log.error("Failed to load Word2Vec model from classpath resource '{}'; falling back to keyword matching", modelPath, e);
        }
        finally {
            // Free the disk copy as soon as the model has been read; deleteOnExit remains the fallback.
            if (tempFile != null && !tempFile.delete()) {
                log.debug("Could not delete temporary Word2Vec file {}; it will be removed on JVM exit", tempFile);
            }
        }
    }

    /**
     * Scores how well the tokenized question matches {@code model}.
     *
     * <p>Before the Word2Vec model is loaded, this returns the keyword score alone. Afterwards it
     * returns the weighted sum of the keyword, semantic, field-completeness and context scores,
     * using the weights from {@link AiOrderCreationModelProperties}. A model-name match adds
     * {@value #MODEL_NAME_BONUS}, capped at 1.0.
     *
     * @param questionWords tokens of the user's question
     * @param model         candidate model to score
     * @return the matching score
     */
    @Override
    public double calculateModelMatchingScore(List<String> questionWords, ModelInfo model) {
        double keywordScore = this.calculateKeywordMatchingScore(questionWords, model);
        if (!modelLoaded.get()) {
            return keywordScore;
        }
        double semanticScore = this.calculateSemanticSimilarityScore(questionWords, model);
        double completenessScore = this.calculateFieldCompletenessScore(questionWords, model);
        double contextScore = this.calculateContextRelevanceScore(questionWords, model);
        AiOrderCreationModelProperties props = this.aiOrderCreationModelProperties;
        double totalScore = keywordScore * props.getKeyWordWeight()
                + semanticScore * props.getSemanticWeight()
                + completenessScore * props.getCompletenessWeight()
                + contextScore * props.getContextWeight();
        if (this.hasModelNameMatch(questionWords, model)) {
            totalScore = Math.min(1.0, totalScore + MODEL_NAME_BONUS);
        }
        return totalScore;
    }

    /**
     * Averages the weights of the question words that exactly equal a field name or alias.
     * Each term gets weight {@code 1 / (field.getTimes() + 1)}. If several fields declare the
     * same term, the last field wins.
     *
     * @return average weight of the matched words, or 0.0 when no word matches
     */
    private double calculateKeywordMatchingScore(List<String> questionWords, ModelInfo model) {
        Map<String, Double> termWeights = new HashMap<>();
        for (FieldSpec field : model.getFieldSpecs()) {
            double weight = 1.0 / (double) (field.getTimes() + 1L);
            termWeights.put(field.getFieldName(), weight);
            for (String alias : field.getFieldAlias()) {
                termWeights.put(alias, weight);
            }
        }
        double totalScore = 0.0;
        int matchCount = 0;
        for (String word : questionWords) {
            Double weight = termWeights.get(word);
            if (weight == null) {
                continue;
            }
            totalScore += weight;
            ++matchCount;
        }
        return matchCount > 0 ? totalScore / (double) matchCount : 0.0;
    }

    /**
     * Averages the Word2Vec similarities above {@value #SEMANTIC_SIMILARITY_THRESHOLD} between
     * field names and question words. Aliases are compared with a question word only after that
     * word has matched the field name.
     * NOTE(review): aliases of fields whose name is out of vocabulary, or whose name did not
     * match, never contribute. Confirm this nesting is intended.
     *
     * @return average similarity of the qualifying pairs, or 0.0 when there are none
     */
    private double calculateSemanticSimilarityScore(List<String> questionWords, ModelInfo model) {
        double totalSim = 0.0;
        int validPairs = 0;
        for (FieldSpec field : model.getFieldSpecs()) {
            String fieldName = field.getFieldName();
            if (!this.word2Vec.hasWord(fieldName)) {
                continue;
            }
            for (String qWord : questionWords) {
                if (!this.word2Vec.hasWord(qWord)) {
                    continue;
                }
                double nameSim = this.word2Vec.similarity(fieldName, qWord);
                // `!(x > t)` rather than `x <= t` so that NaN similarities are skipped.
                if (!(nameSim > SEMANTIC_SIMILARITY_THRESHOLD)) {
                    continue;
                }
                totalSim += nameSim;
                ++validPairs;
                for (String alias : field.getFieldAlias()) {
                    if (!this.word2Vec.hasWord(alias)) {
                        continue;
                    }
                    double aliasSim = this.word2Vec.similarity(alias, qWord);
                    if (aliasSim > SEMANTIC_SIMILARITY_THRESHOLD) {
                        totalSim += aliasSim;
                        ++validPairs;
                    }
                }
            }
        }
        return validPairs > 0 ? totalSim / (double) validPairs : 0.0;
    }

    /**
     * Computes the fraction of the model's fields that are covered by at least one question word.
     * Matched fields are counted by distinct {@code getFieldCode()}.
     *
     * @return a value in [0, 1]; 0.0 when the model has no fields (avoids 0/0 = NaN)
     */
    private double calculateFieldCompletenessScore(List<String> questionWords, ModelInfo model) {
        int fieldCount = model.getFieldSpecs().size();
        if (fieldCount == 0) {
            return 0.0;
        }
        Set<String> matchedFieldCodes = new HashSet<>();
        for (FieldSpec field : model.getFieldSpecs()) {
            if (this.fieldMatchesAnyWord(field, questionWords)) {
                matchedFieldCodes.add(field.getFieldCode());
            }
        }
        return (double) matchedFieldCodes.size() / (double) fieldCount;
    }

    /**
     * Returns true if some question word is one of the field's names, or if its Word2Vec
     * similarity to the field name exceeds {@value #FIELD_SIMILARITY_THRESHOLD}.
     */
    private boolean fieldMatchesAnyWord(FieldSpec field, List<String> questionWords) {
        String fieldName = field.getFieldName();
        for (String word : questionWords) {
            if (field.getAllNames().contains(word)) {
                return true;
            }
            if (this.word2Vec.hasWord(fieldName)
                    && this.word2Vec.hasWord(word)
                    && this.word2Vec.similarity(fieldName, word) > FIELD_SIMILARITY_THRESHOLD) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes {@code 0.5 + 0.5 * avgSimilarity(question, description words)}. Returns 0.1 when
     * the model has no description.
     */
    private double calculateContextRelevanceScore(List<String> questionWords, ModelInfo model) {
        String modelDesc = model.getModelDesc();
        if (modelDesc == null || modelDesc.isEmpty()) {
            return 0.1;
        }
        List<String> modelDescWords = Arrays.asList(WHITESPACE.split(modelDesc));
        double contextSim = this.calculateAverageSimilarity(questionWords, modelDescWords);
        return contextSim * 0.5 + 0.5;
    }

    /**
     * Averages the Word2Vec similarity over every pair of in-vocabulary words drawn from the two
     * lists. Returns 0.1 when no such pair exists.
     */
    private double calculateAverageSimilarity(List<String> words1, List<String> words2) {
        double totalSim = 0.0;
        int pairCount = 0;
        for (String w1 : words1) {
            if (!this.word2Vec.hasWord(w1)) {
                continue;
            }
            for (String w2 : words2) {
                if (!this.word2Vec.hasWord(w2)) {
                    continue;
                }
                totalSim += this.word2Vec.similarity(w1, w2);
                ++pairCount;
            }
        }
        return pairCount > 0 ? totalSim / (double) pairCount : 0.1;
    }

    /**
     * Returns true when some non-empty question word and the model name contain one another.
     * Null or empty names and words are ignored. An empty string is contained in every string,
     * so without this guard every question would wrongly receive the name bonus.
     */
    private boolean hasModelNameMatch(List<String> questionWords, ModelInfo model) {
        String modelName = model.getModelName();
        if (modelName == null || modelName.isEmpty()) {
            return false;
        }
        return questionWords.stream()
                .filter(word -> word != null && !word.isEmpty())
                .anyMatch(word -> modelName.contains(word) || word.contains(modelName));
    }
}

