/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.hook.pii;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.HookPositions;
import com.alibaba.cloud.ai.graph.agent.hook.JumpTo;
import com.alibaba.cloud.ai.graph.agent.hook.ModelHook;
import com.alibaba.cloud.ai.graph.agent.hook.pii.PIIDetectionException;
import com.alibaba.cloud.ai.graph.agent.hook.pii.PIIDetector;
import com.alibaba.cloud.ai.graph.agent.hook.pii.PIIDetectors;
import com.alibaba.cloud.ai.graph.agent.hook.pii.PIIMatch;
import com.alibaba.cloud.ai.graph.agent.hook.pii.PIIType;
import com.alibaba.cloud.ai.graph.agent.hook.pii.RedactionStrategy;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;

/**
 * Model hook that looks for one category of personally identifiable information (PII) in the
 * {@code "messages"} state and rewrites it or rejects it.
 *
 * <p>The hook runs before and after the model call. Which message kinds are inspected is controlled
 * by {@link Builder#applyToInput(boolean)} (user messages), {@link Builder#applyToOutput(boolean)}
 * (assistant messages) and {@link Builder#applyToToolResults(boolean)} (tool responses). Matches are
 * handled according to the configured {@link RedactionStrategy}:
 * <ul>
 *   <li>{@code REDACT} replaces each match with {@code [REDACTED_<TYPE>]};</li>
 *   <li>{@code MASK} keeps only the last four characters;</li>
 *   <li>{@code HASH} replaces the match with a short SHA-256-derived token;</li>
 *   <li>{@code BLOCK} throws {@link PIIDetectionException}.</li>
 * </ul>
 */
@HookPositions(value={HookPosition.BEFORE_MODEL, HookPosition.AFTER_MODEL})
public class PIIDetectionHook
extends ModelHook {
    private static final String MESSAGES_KEY = "messages";
    private static final int MASK_VISIBLE_CHARS = 4;

    private final PIIType piiType;
    private final RedactionStrategy strategy;
    private final PIIDetector detector;
    private final boolean applyToInput;
    private final boolean applyToOutput;
    private final boolean applyToToolResults;

    private PIIDetectionHook(Builder builder) {
        this.piiType = builder.piiType;
        this.strategy = builder.strategy;
        // Fall back to the built-in detector for the type when none was supplied explicitly.
        this.detector = builder.detector != null ? builder.detector : getDefaultDetector(this.piiType);
        this.applyToInput = builder.applyToInput;
        this.applyToOutput = builder.applyToOutput;
        this.applyToToolResults = builder.applyToToolResults;
    }

    /**
     * Creates a new builder.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Scans every message in the {@code "messages"} state before the model is invoked.
     *
     * <p>Each message is processed only if its kind is enabled (user messages for input, assistant
     * messages for output, tool responses for tool results).
     *
     * @param state current graph state
     * @param config run configuration (unused)
     * @return a future holding {@code {"messages": rewrittenList}} when anything was rewritten,
     *     otherwise an empty map
     * @throws PIIDetectionException if the strategy is {@code BLOCK} and PII is found
     */
    @Override
    public CompletableFuture<Map<String, Object>> beforeModel(OverAllState state, RunnableConfig config) {
        List<Message> messages = readMessages(state);
        List<Message> processedMessages = new ArrayList<>(messages.size());
        boolean hasChanges = false;
        for (Message message : messages) {
            Message processed = this.processMessage(message);
            processedMessages.add(processed);
            // processMessage returns the very same instance when nothing was rewritten.
            if (processed != message) {
                hasChanges = true;
            }
        }
        if (!hasChanges) {
            return noUpdate();
        }
        return messagesUpdate(processedMessages);
    }

    /**
     * Scans the most recent assistant message after the model has responded.
     *
     * <p>This is a no-op unless output scanning is enabled.
     *
     * @param state current graph state
     * @param config run configuration (unused)
     * @return a future holding {@code {"messages": updatedList}} when the last assistant message was
     *     rewritten, otherwise an empty map
     * @throws PIIDetectionException if the strategy is {@code BLOCK} and PII is found
     */
    @Override
    public CompletableFuture<Map<String, Object>> afterModel(OverAllState state, RunnableConfig config) {
        if (!this.applyToOutput) {
            return noUpdate();
        }
        List<Message> messages = readMessages(state);
        int lastIndex = -1;
        for (int i = messages.size() - 1; i >= 0; --i) {
            if (messages.get(i) instanceof AssistantMessage) {
                lastIndex = i;
                break;
            }
        }
        if (lastIndex < 0) {
            return noUpdate();
        }
        AssistantMessage aiMessage = (AssistantMessage) messages.get(lastIndex);
        AssistantMessage updatedMessage = this.processContent(aiMessage);
        if (updatedMessage == aiMessage) {
            return noUpdate();
        }
        List<Message> updatedMessages = new ArrayList<>(messages);
        updatedMessages.set(lastIndex, updatedMessage);
        return messagesUpdate(updatedMessages);
    }

    /** Reads the message list from state, treating an absent key as an empty list. */
    @SuppressWarnings("unchecked")
    private static List<Message> readMessages(OverAllState state) {
        return (List<Message>) state.value(MESSAGES_KEY).orElse(List.of());
    }

    /** Wraps a rewritten message list as a state update. */
    private static CompletableFuture<Map<String, Object>> messagesUpdate(List<Message> messages) {
        Map<String, Object> updates = new HashMap<>();
        updates.put(MESSAGES_KEY, messages);
        return CompletableFuture.completedFuture(updates);
    }

    /** Returns a completed future carrying no state changes. */
    private static CompletableFuture<Map<String, Object>> noUpdate() {
        return CompletableFuture.completedFuture(Map.of());
    }

    /** Dispatches a message to the matching processor if its kind is enabled; otherwise returns it unchanged. */
    private Message processMessage(Message message) {
        if (this.applyToInput && message instanceof UserMessage) {
            return this.processContent((UserMessage) message);
        }
        if (this.applyToOutput && message instanceof AssistantMessage) {
            return this.processContent((AssistantMessage) message);
        }
        if (this.applyToToolResults && message instanceof ToolResponseMessage) {
            return this.processToolResponse((ToolResponseMessage) message);
        }
        return message;
    }

    /** Returns a rewritten copy of the user message, or the same instance when nothing changed. */
    private UserMessage processContent(UserMessage message) {
        String redacted = this.redact(message.getText());
        if (redacted == null) {
            return message;
        }
        // Carry media over so attachments are not lost when only the text is rewritten.
        return UserMessage.builder()
                .text(redacted)
                .media(message.getMedia())
                .metadata(message.getMetadata())
                .build();
    }

    /** Returns a rewritten copy of the assistant message, or the same instance when nothing changed. */
    private AssistantMessage processContent(AssistantMessage message) {
        String redacted = this.redact(message.getText());
        if (redacted == null) {
            return message;
        }
        return new AssistantMessage(redacted, message.getMetadata(), message.getToolCalls(), message.getMedia());
    }

    /** Returns a copy with each tool response's data rewritten, or the same instance when nothing changed. */
    private ToolResponseMessage processToolResponse(ToolResponseMessage message) {
        List<ToolResponseMessage.ToolResponse> responses = new ArrayList<>(message.getResponses().size());
        boolean hasChanges = false;
        for (ToolResponseMessage.ToolResponse response : message.getResponses()) {
            String redacted = this.redact(response.responseData());
            if (redacted == null) {
                responses.add(response);
            } else {
                responses.add(new ToolResponseMessage.ToolResponse(response.id(), response.name(), redacted));
                hasChanges = true;
            }
        }
        return hasChanges ? new ToolResponseMessage(responses, message.getMetadata()) : message;
    }

    /**
     * Applies detection to one piece of text.
     *
     * @param text text to scan; {@code null} and empty text are left alone
     * @return the rewritten text, or {@code null} when the text needs no change
     * @throws PIIDetectionException if the strategy is {@code BLOCK} and PII is found
     */
    private String redact(String text) {
        ProcessResult result = this.processText(text);
        if (!result.hasMatches) {
            return null;
        }
        if (this.strategy == RedactionStrategy.BLOCK) {
            throw new PIIDetectionException(this.piiType.name(), result.matches);
        }
        return result.redactedText.equals(text) ? null : result.redactedText;
    }

    /**
     * Runs the detector and, for rewriting strategies, produces the rewritten text.
     *
     * <p>Under {@code BLOCK} the text is returned untouched because callers reject it anyway.
     */
    private ProcessResult processText(String text) {
        if (text == null || text.isEmpty()) {
            return new ProcessResult(text, false, List.of());
        }
        List<PIIMatch> matches = this.detector.detect(text);
        if (matches == null || matches.isEmpty()) {
            return new ProcessResult(text, false, List.of());
        }
        if (this.strategy == RedactionStrategy.BLOCK) {
            return new ProcessResult(text, true, matches);
        }
        return new ProcessResult(this.applyStrategy(text, matches), true, matches);
    }

    /**
     * Replaces every matched span of {@code text} according to the configured strategy.
     *
     * <p>The detector's list is copied before sorting, so immutable lists are accepted. Overlapping
     * matches are merged: the part of a match already covered by an earlier replacement is skipped,
     * and only its uncovered tail is replaced. This avoids both a crash and leaking that tail.
     */
    private String applyStrategy(String text, List<PIIMatch> matches) {
        if (matches.isEmpty()) {
            return text;
        }
        List<PIIMatch> ordered = new ArrayList<>(matches);
        ordered.sort(Comparator.comparingInt(m -> m.start));
        StringBuilder result = new StringBuilder(text.length());
        int cursor = 0;
        for (PIIMatch match : ordered) {
            int start = Math.max(match.start, cursor);
            if (match.end <= start) {
                // Entirely inside a span that has already been replaced.
                continue;
            }
            result.append(text, cursor, start);
            String value = start == match.start ? match.value : text.substring(start, match.end);
            result.append(this.replacementFor(value));
            cursor = match.end;
        }
        result.append(text, cursor, text.length());
        return result.toString();
    }

    /** Produces the replacement string for one matched value under the configured strategy. */
    private String replacementFor(String value) {
        switch (this.strategy) {
            case MASK:
                return this.maskValue(value);
            case HASH:
                return this.hashValue(value);
            case REDACT:
            default:
                // Any non-rewriting strategy (e.g. BLOCK) still gets a visible marker rather than silent deletion.
                return "[REDACTED_" + this.piiType.name() + "]";
        }
    }

    /** Masks all but the last four characters; values of four characters or fewer become {@code ****}. */
    private String maskValue(String value) {
        if (value.length() <= MASK_VISIBLE_CHARS) {
            return "****";
        }
        String masked = "*".repeat(value.length() - MASK_VISIBLE_CHARS);
        return masked + value.substring(value.length() - MASK_VISIBLE_CHARS);
    }

    /**
     * Replaces a value with a deterministic token built from the first 32 bits of its SHA-256 digest.
     *
     * <p>Example token: {@code <email_hash:1a2b3c4d>}.
     */
    private String hashValue(String value) {
        byte[] digest;
        try {
            digest = MessageDigest.getInstance("SHA-256").digest(value.getBytes(StandardCharsets.UTF_8));
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to provide SHA-256.
            throw new IllegalStateException("SHA-256 is not available", e);
        }
        int prefix = ((digest[0] & 0xFF) << 24) | ((digest[1] & 0xFF) << 16)
                | ((digest[2] & 0xFF) << 8) | (digest[3] & 0xFF);
        // Locale.ROOT keeps the tag stable under locales such as Turkish (dotless i).
        return String.format(Locale.ROOT, "<%s_hash:%08x>", this.piiType.name().toLowerCase(Locale.ROOT), prefix);
    }

    /**
     * Returns the built-in detector for a PII type.
     *
     * @throws IllegalArgumentException if no built-in detector exists for {@code type}
     */
    private static PIIDetector getDefaultDetector(PIIType type) {
        switch (type) {
            case EMAIL:
                return PIIDetectors.emailDetector();
            case CREDIT_CARD:
                return PIIDetectors.creditCardDetector();
            case IP:
                return PIIDetectors.ipDetector();
            case MAC_ADDRESS:
                return PIIDetectors.macAddressDetector();
            case URL:
                return PIIDetectors.urlDetector();
            default:
                throw new IllegalArgumentException("No default detector for PII type: " + type);
        }
    }

    @Override
    public String getName() {
        return "PIIDetection[" + this.piiType.name() + "]";
    }

    /** This hook never redirects graph execution. */
    @Override
    public List<JumpTo> canJumpTo() {
        return List.of();
    }

    /**
     * Builder for {@link PIIDetectionHook}.
     *
     * <p>Defaults:
     * <ul>
     *   <li>strategy is {@code REDACT};</li>
     *   <li>user input is scanned;</li>
     *   <li>model output and tool results are not scanned.</li>
     * </ul>
     */
    public static class Builder {
        private PIIType piiType;
        private RedactionStrategy strategy = RedactionStrategy.REDACT;
        private PIIDetector detector;
        private boolean applyToInput = true;
        private boolean applyToOutput = false;
        private boolean applyToToolResults = false;

        /** Sets the PII category to detect (required). */
        public Builder piiType(PIIType piiType) {
            this.piiType = piiType;
            return this;
        }

        /** Sets how detected PII is handled (default {@code REDACT}). */
        public Builder strategy(RedactionStrategy strategy) {
            this.strategy = strategy;
            return this;
        }

        /** Overrides the detector; when unset, the built-in detector for the PII type is used. */
        public Builder detector(PIIDetector detector) {
            this.detector = detector;
            return this;
        }

        /** Whether user messages are scanned (default {@code true}). */
        public Builder applyToInput(boolean applyToInput) {
            this.applyToInput = applyToInput;
            return this;
        }

        /** Whether assistant messages are scanned (default {@code false}). */
        public Builder applyToOutput(boolean applyToOutput) {
            this.applyToOutput = applyToOutput;
            return this;
        }

        /** Whether tool responses are scanned (default {@code false}). */
        public Builder applyToToolResults(boolean applyToToolResults) {
            this.applyToToolResults = applyToToolResults;
            return this;
        }

        /**
         * Builds the hook.
         *
         * @throws IllegalArgumentException if the PII type or strategy is missing, or no detector was
         *     given and the type has no built-in one
         */
        public PIIDetectionHook build() {
            if (this.piiType == null) {
                throw new IllegalArgumentException("piiType must be specified");
            }
            if (this.strategy == null) {
                throw new IllegalArgumentException("strategy must be specified");
            }
            return new PIIDetectionHook(this);
        }
    }

    /** Outcome of scanning one piece of text. */
    private static class ProcessResult {
        final String redactedText;
        final boolean hasMatches;
        final List<PIIMatch> matches;

        ProcessResult(String redactedText, boolean hasMatches, List<PIIMatch> matches) {
            this.redactedText = redactedText;
            this.hasMatches = hasMatches;
            this.matches = matches;
        }
    }
}

