/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.interceptor.toolselection;

import com.alibaba.cloud.ai.graph.agent.interceptor.ModelCallHandler;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelInterceptor;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelRequest;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelResponse;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

/**
 * Model interceptor that asks a secondary chat model which of the request's tools are relevant
 * to the most recent user message, and forwards the request with only those tools attached.
 *
 * <p>Selection is skipped (the request is forwarded untouched) when the request has no tools,
 * when the tool count is already within {@code maxTools}, or when no user message is present.
 * If the selection call or the parsing of its reply fails, every available tool is kept.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class ToolSelectionInterceptor extends ModelInterceptor {
    private static final Logger log = LoggerFactory.getLogger(ToolSelectionInterceptor.class);
    private static final String DEFAULT_SYSTEM_PROMPT = "Your goal is to select the most relevant tools for answering the user's query.";
    private final ChatModel selectionModel;
    private final String systemPrompt;
    /** Upper bound on forwarded tools; {@code null} means no bound. */
    private final Integer maxTools;
    /** Tool names that are retained whenever they are available on the request. */
    private final Set<String> alwaysInclude;
    private final ObjectMapper objectMapper;

    private ToolSelectionInterceptor(Builder builder) {
        this.selectionModel = builder.selectionModel;
        this.systemPrompt = builder.systemPrompt;
        this.maxTools = builder.maxTools;
        // Defensive copy so later changes to the builder's set do not leak into this instance.
        this.alwaysInclude = new HashSet<>();
        if (builder.alwaysInclude != null) {
            this.alwaysInclude.addAll(builder.alwaysInclude);
        }
        // Selection models frequently add extra fields (e.g. a rationale); do not reject those.
        this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
    }

    /**
     * Creates a new builder.
     *
     * @return an empty builder using the default system prompt
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Narrows the request's tool list via the selection model and delegates to {@code handler}.
     *
     * @param request the outgoing model request
     * @param handler the next handler in the chain
     * @return whatever {@code handler} returns for the (possibly filtered) request
     */
    @Override
    public ModelResponse interceptModel(ModelRequest request, ModelCallHandler handler) {
        List<String> availableTools = request.getTools();
        boolean nothingToSelect = availableTools == null || availableTools.isEmpty();
        if (nothingToSelect || (this.maxTools != null && availableTools.size() <= this.maxTools)) {
            return handler.call(request);
        }
        String lastUserQuery = this.findLastUserMessage(request.getMessages());
        if (lastUserQuery == null) {
            log.debug("No user message found, skipping tool selection");
            return handler.call(request);
        }
        Set<String> selectedToolNames = this.selectTools(availableTools, lastUserQuery);
        log.info("Selected {} tools from {} available: {}", selectedToolNames.size(), availableTools.size(), selectedToolNames);
        // Filter the original list so the forwarded tools keep the request's ordering.
        List<String> filteredTools = availableTools.stream().filter(selectedToolNames::contains).collect(Collectors.toList());
        ModelRequest filteredRequest = ModelRequest.builder(request).tools(filteredTools).build();
        return handler.call(filteredRequest);
    }

    /**
     * Returns the text of the last {@link UserMessage} in {@code messages}.
     *
     * @param messages conversation messages, may be {@code null}
     * @return the message text, or {@code null} when there is no user message
     */
    private String findLastUserMessage(List<Message> messages) {
        if (messages == null) {
            return null;
        }
        for (int i = messages.size() - 1; i >= 0; --i) {
            Message msg = messages.get(i);
            if (msg instanceof UserMessage) {
                return msg.getText();
            }
        }
        return null;
    }

    /**
     * Asks the selection model to pick tools for {@code userQuery}.
     *
     * @param toolNames names of all tools available on the request
     * @param userQuery text of the latest user message
     * @return the names to keep; all of {@code toolNames} if the selection call or parsing fails
     */
    private Set<String> selectTools(List<String> toolNames, String userQuery) {
        try {
            StringBuilder toolList = new StringBuilder();
            for (String toolName : toolNames) {
                toolList.append("- ").append(toolName).append("\n");
            }
            String maxToolsInstruction = this.maxTools != null ? "\nIMPORTANT: List the tool names in order of relevance. Select at most " + this.maxTools + " tools." : "";
            List<Message> selectionMessages = List.of(new SystemMessage(this.systemPrompt + maxToolsInstruction), new UserMessage("Available tools:\n" + toolList + "\nUser query: " + userQuery + "\n\nRespond with a JSON object containing a 'tools' array with the selected tool names: {\"tools\": [\"tool1\", \"tool2\"]}"));
            ChatResponse response = this.selectionModel.call(new Prompt(selectionMessages));
            String responseText = response.getResult().getOutput().getText();
            List<String> rankedNames = this.parseToolSelection(responseText);
            return this.applyLimits(toolNames, rankedNames);
        }
        catch (Exception e) {
            // Best effort: a broken selection step must not strip tools from the main call.
            log.warn("Tool selection failed, using all tools: {}", e.getMessage(), e);
            return new HashSet<>(toolNames);
        }
    }

    /**
     * Combines the always-included tools with the model's ranked picks.
     *
     * <p>Always-included tools that are available are added first and are never dropped. The
     * model's picks then fill the remaining slots in the order the model listed them, until
     * {@code maxTools} entries are present. Names that are not in {@code toolNames} are ignored
     * so they cannot consume a slot.
     *
     * @param toolNames   names of all tools available on the request
     * @param rankedNames tool names returned by the selection model, in reply order
     * @return the names to keep, in insertion order
     */
    private Set<String> applyLimits(List<String> toolNames, List<String> rankedNames) {
        Set<String> available = new HashSet<>(toolNames);
        Set<String> kept = new LinkedHashSet<>();
        for (String pinned : this.alwaysInclude) {
            if (available.contains(pinned)) {
                kept.add(pinned);
            }
        }
        for (String candidate : rankedNames) {
            if (this.maxTools != null && kept.size() >= this.maxTools) {
                break;
            }
            if (available.contains(candidate)) {
                kept.add(candidate);
            }
        }
        return kept;
    }

    /**
     * Extracts the ordered {@code tools} array from the selection model's reply.
     *
     * <p>Only the text between the first '{' and the last '}' is parsed, so a JSON object wrapped
     * in a Markdown fence or surrounded by prose is still accepted.
     *
     * @param responseText raw reply text, may be {@code null}
     * @return the non-null tool names in the order they appear in the reply
     * @throws IOException              if the extracted text is not valid JSON for the expected shape
     * @throws IllegalArgumentException if there is no JSON object or no {@code tools} array
     */
    private List<String> parseToolSelection(String responseText) throws IOException {
        if (responseText == null) {
            throw new IllegalArgumentException("Selection model returned no text");
        }
        int start = responseText.indexOf('{');
        int end = responseText.lastIndexOf('}');
        if (start < 0 || end < start) {
            throw new IllegalArgumentException("Selection response contains no JSON object");
        }
        ToolSelectionResponse parsed = this.objectMapper.readValue(responseText.substring(start, end + 1), ToolSelectionResponse.class);
        if (parsed == null || parsed.tools == null) {
            throw new IllegalArgumentException("Selection response has no 'tools' array");
        }
        List<String> names = new ArrayList<>(parsed.tools.size());
        for (String name : parsed.tools) {
            if (name != null) {
                names.add(name);
            }
        }
        return names;
    }

    /**
     * Returns this interceptor's name.
     *
     * @return the constant {@code "ToolSelection"}
     */
    @Override
    public String getName() {
        return "ToolSelection";
    }

    /** Fluent builder for {@link ToolSelectionInterceptor}. */
    public static class Builder {
        private ChatModel selectionModel;
        private String systemPrompt = DEFAULT_SYSTEM_PROMPT;
        private Integer maxTools;
        private Set<String> alwaysInclude;

        /**
         * Sets the chat model used to perform the selection (required).
         *
         * @param selectionModel the model to query
         * @return this builder
         */
        public Builder selectionModel(ChatModel selectionModel) {
            this.selectionModel = selectionModel;
            return this;
        }

        /**
         * Sets the system prompt sent to the selection model.
         *
         * @param systemPrompt the prompt text; {@code null} restores the default prompt
         * @return this builder
         */
        public Builder systemPrompt(String systemPrompt) {
            // A null prompt would otherwise be concatenated as the literal text "null".
            this.systemPrompt = systemPrompt != null ? systemPrompt : DEFAULT_SYSTEM_PROMPT;
            return this;
        }

        /**
         * Sets the maximum number of tools to forward.
         *
         * @param maxTools the limit, must be positive
         * @return this builder
         * @throws IllegalArgumentException if {@code maxTools} is zero or negative
         */
        public Builder maxTools(int maxTools) {
            if (maxTools <= 0) {
                throw new IllegalArgumentException("maxTools must be > 0");
            }
            this.maxTools = maxTools;
            return this;
        }

        /**
         * Sets the tool names that are always retained, replacing any previous value.
         *
         * @param alwaysInclude tool names, may be {@code null} for none
         * @return this builder
         */
        public Builder alwaysInclude(Set<String> alwaysInclude) {
            this.alwaysInclude = alwaysInclude;
            return this;
        }

        /**
         * Sets the tool names that are always retained, replacing any previous value.
         *
         * @param toolNames tool names
         * @return this builder
         */
        public Builder alwaysInclude(String ... toolNames) {
            this.alwaysInclude = new HashSet<>(Arrays.asList(toolNames));
            return this;
        }

        /**
         * Builds the interceptor.
         *
         * @return a new {@link ToolSelectionInterceptor}
         * @throws IllegalStateException if no selection model was set
         */
        public ToolSelectionInterceptor build() {
            if (this.selectionModel == null) {
                throw new IllegalStateException("selectionModel is required");
            }
            return new ToolSelectionInterceptor(this);
        }
    }

    /** JSON shape expected from the selection model: {@code {"tools": ["name", ...]}}. */
    private static class ToolSelectionResponse {
        @JsonProperty(value="tools")
        public List<String> tools;

        private ToolSelectionResponse() {
        }
    }
}

