/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.NodeActionWithConfig;
import com.alibaba.cloud.ai.graph.agent.interceptor.InterceptorChain;
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallHandler;
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallRequest;
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallResponse;
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolInterceptor;
import com.alibaba.cloud.ai.graph.state.RemoveByHash;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;

public class AgentToolNode
implements NodeActionWithConfig {
    public static final String TOOL_NODE_NAME = "tool";
    private static final Logger logger = LoggerFactory.getLogger(AgentToolNode.class);
    private final String agentName;
    private boolean enableActingLog;
    private List<ToolCallback> toolCallbacks;
    private List<ToolInterceptor> toolInterceptors = new ArrayList<ToolInterceptor>();
    private ToolCallbackResolver toolCallbackResolver;

    public AgentToolNode(Builder builder) {
        this.agentName = builder.agentName;
        this.enableActingLog = builder.enableActingLog;
        this.toolCallbackResolver = builder.toolCallbackResolver;
        this.toolCallbacks = builder.toolCallbacks;
    }

    public void setToolCallbacks(List<ToolCallback> toolCallbacks) {
        this.toolCallbacks = toolCallbacks;
    }

    public void setToolInterceptors(List<ToolInterceptor> toolInterceptors) {
        this.toolInterceptors = toolInterceptors;
    }

    void setToolCallbackResolver(ToolCallbackResolver toolCallbackResolver) {
        this.toolCallbackResolver = toolCallbackResolver;
    }

    public List<ToolCallback> getToolCallbacks() {
        return this.toolCallbacks;
    }

    public Map<String, Object> apply(OverAllState state, RunnableConfig config) throws Exception {
        List messages = (List)state.value("messages").orElseThrow();
        Message lastMessage = (Message)messages.get(messages.size() - 1);
        HashMap<String, Object> updatedState = new HashMap<String, Object>();
        HashMap<String, Object> extraStateFromToolCall = new HashMap<String, Object>();
        if (lastMessage instanceof AssistantMessage) {
            AssistantMessage assistantMessage = (AssistantMessage)lastMessage;
            ArrayList<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<ToolResponseMessage.ToolResponse>();
            if (this.enableActingLog) {
                logger.info("[ThreadId {}] Agent {} acting with {} tools.", new Object[]{config.threadId().orElse("$default"), this.agentName, assistantMessage.getToolCalls().size()});
            }
            for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                ToolCallResponse response = this.executeToolCallWithInterceptors(toolCall, state, config, extraStateFromToolCall);
                toolResponses.add(response.toToolResponse());
            }
            ToolResponseMessage toolResponseMessage = new ToolResponseMessage(toolResponses, Map.of());
            if (this.enableActingLog) {
                logger.info("[ThreadId {}] Agent {} acting returned: {}", new Object[]{config.threadId().orElse("$default"), this.agentName, toolResponseMessage});
            }
            updatedState.put("messages", toolResponseMessage);
        } else if (lastMessage instanceof ToolResponseMessage) {
            ToolResponseMessage toolResponseMessage = (ToolResponseMessage)lastMessage;
            if (messages.size() < 2) {
                throw new IllegalStateException("Cannot find AssistantMessage before ToolResponseMessage");
            }
            Message secondLastMessage = (Message)messages.get(messages.size() - 2);
            if (!(secondLastMessage instanceof AssistantMessage)) {
                throw new IllegalStateException("Message before ToolResponseMessage is not an AssistantMessage");
            }
            AssistantMessage assistantMessage = (AssistantMessage)secondLastMessage;
            List existingResponses = toolResponseMessage.getResponses();
            ArrayList<ToolResponseMessage.ToolResponse> allResponses = new ArrayList<ToolResponseMessage.ToolResponse>(existingResponses);
            Set executedToolNames = existingResponses.stream().map(ToolResponseMessage.ToolResponse::name).collect(Collectors.toSet());
            if (this.enableActingLog) {
                logger.info("[ThreadId {}] Agent {} acting with {} tools ({} tools provided results).", new Object[]{config.threadId().orElse("$default"), this.agentName, assistantMessage.getToolCalls().size(), existingResponses.size()});
            }
            for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                if (executedToolNames.contains(toolCall.name())) continue;
                ToolCallResponse response = this.executeToolCallWithInterceptors(toolCall, state, config, extraStateFromToolCall);
                allResponses.add(response.toToolResponse());
            }
            ArrayList<Object> newMessages = new ArrayList<Object>();
            ToolResponseMessage newToolResponseMessage = new ToolResponseMessage(allResponses, Map.of());
            newMessages.add(newToolResponseMessage);
            newMessages.add(new RemoveByHash((Object)assistantMessage));
            updatedState.put("messages", newMessages);
            if (this.enableActingLog) {
                logger.info("[ThreadId {}] Agent {} acting successfully returned.", (Object)config.threadId().orElse("$default"), (Object)this.agentName);
                if (logger.isDebugEnabled()) {
                    logger.debug("[ThreadId {}] Agent {} acting returned: {}", new Object[]{config.threadId().orElse("$default"), this.agentName, toolResponseMessage});
                }
            }
        } else {
            throw new IllegalStateException("Last message is not an AssistantMessage or ToolResponseMessage");
        }
        updatedState.putAll(extraStateFromToolCall);
        return updatedState;
    }

    private ToolCallResponse executeToolCallWithInterceptors(AssistantMessage.ToolCall toolCall, OverAllState state, RunnableConfig config, Map<String, Object> extraStateFromToolCall) {
        ToolCallRequest request = ToolCallRequest.builder().toolCall(toolCall).context(config.metadata().orElse(new HashMap())).build();
        ToolCallHandler baseHandler = req -> {
            String result;
            ToolCallback toolCallback = this.resolve(req.getToolName());
            if (this.enableActingLog) {
                logger.info("[ThreadId {}] Agent {} acting, executing tool {}.", new Object[]{config.threadId().orElse("$default"), this.agentName, req.getToolName()});
            }
            try {
                result = toolCallback instanceof FunctionToolCallback ? toolCallback.call(req.getArguments(), new ToolContext(Map.of("_AGENT_STATE_", state, "_AGENT_CONFIG_", config, "_AGENT_STATE_FOR_UPDATE_", extraStateFromToolCall))) : toolCallback.call(req.getArguments());
                if (this.enableActingLog) {
                    logger.info("[ThreadId {}] Agent {} acting, tool {} finished", new Object[]{config.threadId().orElse("$default"), this.agentName, req.getToolName()});
                    if (logger.isDebugEnabled()) {
                        logger.debug("Tool {} returned: {}", (Object)req.getToolName(), (Object)result);
                    }
                }
            }
            catch (Exception e) {
                logger.error("[ThreadId {}] Agent {} acting, tool {} execution failed. The agent loop has ended, please use ToolRetryInterceptor to customize the retry and policy on tool failure. \n", new Object[]{config.threadId().orElse("$default"), this.agentName, req.getToolName(), e});
                throw e;
            }
            return ToolCallResponse.of(req.getToolCallId(), req.getToolName(), result);
        };
        ToolCallHandler chainedHandler = InterceptorChain.chainToolInterceptors(this.toolInterceptors, baseHandler);
        return chainedHandler.call(request);
    }

    private ToolCallback resolve(String toolName) {
        return this.toolCallbacks.stream().filter(callback -> callback.getToolDefinition().name().equals(toolName)).findFirst().orElseGet(() -> this.toolCallbackResolver.resolve(toolName));
    }

    public String getName() {
        return TOOL_NODE_NAME;
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private String agentName;
        private boolean enableActingLog;
        private List<ToolCallback> toolCallbacks = new ArrayList<ToolCallback>();
        private List<String> toolNames = new ArrayList<String>();
        private ToolCallbackResolver toolCallbackResolver;

        private Builder() {
        }

        public Builder agentName(String agentName) {
            this.agentName = agentName;
            return this;
        }

        public Builder enableActingLog(boolean enableActingLog) {
            this.enableActingLog = enableActingLog;
            return this;
        }

        public Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
            this.toolCallbacks = toolCallbacks;
            return this;
        }

        public Builder toolNames(List<String> toolNames) {
            this.toolNames = toolNames;
            return this;
        }

        public Builder toolCallbackResolver(ToolCallbackResolver toolCallbackResolver) {
            this.toolCallbackResolver = toolCallbackResolver;
            return this;
        }

        public AgentToolNode build() {
            return new AgentToolNode(this);
        }
    }
}

