/*
 * Decompiled with CFR 0.152.
 */
package xyz.erupt.ai.call;

import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.core.annotation.Order;
import org.springframework.core.type.filter.AssignableTypeFilter;
import org.springframework.core.type.filter.TypeFilter;
import org.springframework.stereotype.Component;
import xyz.erupt.ai.annotation.AiParam;
import xyz.erupt.ai.call.AiFunctionCall;
import xyz.erupt.ai.call.ParamPromptTemplate;
import xyz.erupt.ai.constants.ResponseFormat;
import xyz.erupt.ai.core.LlmCore;
import xyz.erupt.ai.core.LlmRequest;
import xyz.erupt.ai.model.LLM;
import xyz.erupt.ai.pojo.ChatCompletionMessage;
import xyz.erupt.ai.util.MarkDownUtil;
import xyz.erupt.core.config.GsonFactory;
import xyz.erupt.core.exception.EruptWebApiRuntimeException;
import xyz.erupt.core.service.EruptApplication;
import xyz.erupt.core.util.EruptSpringUtil;

@Component
@Order(value=100)
public class AiFunctionManager
implements ApplicationRunner {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(AiFunctionManager.class);
    private static final Map<String, AiFunctionCall> aiFunctions = new HashMap<String, AiFunctionCall>();

    public void run(ApplicationArguments args) {
        EruptSpringUtil.scannerPackage((String[])EruptApplication.getScanPackage(), (TypeFilter[])new TypeFilter[]{new AssignableTypeFilter(AiFunctionCall.class)}, clazz -> aiFunctions.put(clazz.getSimpleName(), (AiFunctionCall)EruptSpringUtil.getBean((Class)clazz)));
    }

    public String getFunctionCallPrompt() {
        StringBuilder sb = new StringBuilder("\u4e0b\u9762\u662f\u4e00\u7ec4 Function Call \u7684\u6620\u5c04\uff0c\u6839\u636e\u60c5\u51b5\u51b3\u5b9a\u662f\u5426\u8c03\u7528\uff0c\u5426\u5219\u5ffd\u7565\u8fd9\u6bb5\u63d0\u793a\u8bcd\n");
        for (Map.Entry<String, AiFunctionCall> entry : aiFunctions.entrySet()) {
            sb.append("- \u5982\u679c\u7528\u6237\u95ee\uff1a").append(entry.getValue().description()).append("\uff0c\u5c31\u53ea\u56de\u590d\uff1a").append(entry.getKey()).append("\n");
        }
        return sb.toString();
    }

    public boolean exist(String key) {
        return aiFunctions.containsKey(key);
    }

    public String call(String key, LLM llm, String userMessage, List<ChatCompletionMessage> userContext) {
        AiFunctionCall aiFunctionCall = aiFunctions.get(key);
        HashMap<String, Field> params = new HashMap<String, Field>();
        for (Field field : aiFunctionCall.getClass().getDeclaredFields()) {
            Optional.ofNullable(field.getAnnotation(AiParam.class)).ifPresent(it -> params.put(field.getName(), field));
        }
        if (params.isEmpty()) {
            return aiFunctionCall.call(userMessage);
        }
        Map<String, ParamPromptTemplate> promptTemplateMap = AiFunctionManager.getStringParamPromptTemplateMap(params);
        StringBuilder prompt = new StringBuilder();
        prompt.append(userMessage).append("\n\n");
        prompt.append("\u6839\u636e\u4e0a\u9762\u7684\u5185\u5bb9\uff0c\u81ea\u52a8\u8bc6\u522b\u5e76\u586b\u5145\u4e0b\u9762JSON\u7684val\u5b57\u6bb5\uff0c\u6b64JSON\u4e2d\u7684\u6bcf\u4e2avalue\u90fd\u662f\u5177\u4f53\u7684\u751f\u6210\u8981\u6c42\uff0c\u5c06\u4e0d\u540ckey\u7684\u8bc6\u522b\u7ed3\u679c\u653e\u5230\u5bf9\u5e94val\u5b57\u6bb5\u5185\n");
        prompt.append("\u8bf7\u4e25\u683c\u6309\u7167\u4ee5\u4e0bJSON\u683c\u5f0f\u8fd4\u56de\uff0c\u4e0d\u8981\u8fd4\u56de\u5176\u4ed6\u4efb\u4f55\u591a\u4f59\u7684\u5185\u5bb9\u6216\u89e3\u91ca\uff0c\u8bf7\u786e\u4fdd\u53ea\u8fd4\u56de\u7eafJSON\uff1a\n\n");
        prompt.append(GsonFactory.getGson().toJson(promptTemplateMap));
        LlmRequest llmRequest = llm.toLlmRequest();
        llmRequest.setResponseFormat(ResponseFormat.json_object);
        String llmRes = LlmCore.getLLM(llm).chat(llm.toLlmRequest(), prompt.toString(), userContext).getMessageStr();
        llmRes = MarkDownUtil.extractCodeBlock(llmRes);
        try {
            Map res = (Map)GsonFactory.getGson().fromJson(llmRes, new TypeToken<Map<String, ParamPromptTemplate>>(){}.getType());
            for (Map.Entry entry : res.entrySet()) {
                Field field = aiFunctionCall.getClass().getDeclaredField((String)entry.getKey());
                field.setAccessible(true);
                field.set(aiFunctionCall, ((ParamPromptTemplate)entry.getValue()).getVal());
                field.setAccessible(false);
            }
        }
        catch (Exception e) {
            throw new EruptWebApiRuntimeException("Function Call param error: " + e.getMessage() + "\u2192 \n\n" + llmRes);
        }
        return aiFunctionCall.call(prompt.toString());
    }

    private static Map<String, ParamPromptTemplate> getStringParamPromptTemplateMap(Map<String, Field> params) {
        HashMap<String, ParamPromptTemplate> promptTemplateMap = new HashMap<String, ParamPromptTemplate>();
        for (Map.Entry<String, Field> entry : params.entrySet()) {
            AiParam aiFuncParam = entry.getValue().getAnnotation(AiParam.class);
            ParamPromptTemplate promptTemplate = new ParamPromptTemplate();
            promptTemplate.setDescription(aiFuncParam.description());
            promptTemplate.setRequired(aiFuncParam.required());
            promptTemplate.setType(entry.getValue().getType().getSimpleName());
            promptTemplateMap.put(entry.getKey(), promptTemplate);
        }
        return promptTemplateMap;
    }

    @Generated
    public static Map<String, AiFunctionCall> getAiFunctions() {
        return aiFunctions;
    }
}

