/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.interceptor.toolretry;

import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallHandler;
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallRequest;
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallResponse;
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolInterceptor;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
import java.util.function.Predicate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ToolInterceptor} that re-invokes failing tool calls with exponential backoff.
 *
 * <p>Only tools whose name is in the configured tool-name set are retried; when no set is
 * configured, every tool is eligible. A thrown exception is retried only when the configured
 * {@code retryOn} predicate accepts it; otherwise it is re-thrown immediately. After
 * {@code maxRetries + 1} failed attempts the interceptor either throws a
 * {@link RuntimeException} wrapping the last failure ({@link OnFailureBehavior#RAISE}) or
 * returns a {@link ToolCallResponse} whose content is an error message
 * ({@link OnFailureBehavior#RETURN_MESSAGE}).
 *
 * <p>The wait between attempts blocks the calling thread via {@link Thread#sleep(long)}.
 */
public class ToolRetryInterceptor extends ToolInterceptor {
    private static final Logger log = LoggerFactory.getLogger(ToolRetryInterceptor.class);

    /** Number of retries after the first attempt; total attempts is {@code maxRetries + 1}. */
    private final int maxRetries;
    /** Tool names eligible for retry, or {@code null} to retry every tool. */
    private final Set<String> toolNames;
    /** Decides whether a given exception should be retried. */
    private final Predicate<Exception> retryOn;
    /** What to do once every attempt has failed. */
    private final OnFailureBehavior onFailure;
    /** Optional custom formatter for the error message returned on final failure. */
    private final Function<Exception, String> errorFormatter;
    /** Multiplier applied to the delay for each successive retry. */
    private final double backoffFactor;
    /** Delay before the first retry, in milliseconds. */
    private final long initialDelayMs;
    /** Upper bound applied to the backoff delay before jitter, in milliseconds. */
    private final long maxDelayMs;
    /** Whether to randomize each delay by a factor in [0.75, 1.25). */
    private final boolean jitter;

    private ToolRetryInterceptor(Builder builder) {
        this.maxRetries = builder.maxRetries;
        this.toolNames = builder.toolNames != null ? new HashSet<>(builder.toolNames) : null;
        this.retryOn = builder.retryOn;
        this.onFailure = builder.onFailure;
        this.errorFormatter = builder.errorFormatter;
        this.backoffFactor = builder.backoffFactor;
        this.initialDelayMs = builder.initialDelayMs;
        this.maxDelayMs = builder.maxDelayMs;
        this.jitter = builder.jitter;
    }

    /**
     * Creates a builder pre-populated with the defaults.
     *
     * <p>The defaults are: 2 retries, all tools, all exceptions retried, return-message on
     * failure, backoff factor 2.0, 1000 ms initial delay, 60000 ms max delay, and jitter
     * enabled.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Invokes the tool, retrying failures according to this interceptor's configuration.
     *
     * @param request the tool call to execute
     * @param handler the next handler in the chain
     * @return the handler's response. If every attempt fails and the failure behavior is
     *     {@link OnFailureBehavior#RETURN_MESSAGE}, returns a response carrying an error
     *     message instead.
     * @throws RuntimeException wrapping the last failure when every attempt fails and the
     *     failure behavior is {@link OnFailureBehavior#RAISE}; also thrown when the wait
     *     between attempts is interrupted. Exceptions rejected by {@code retryOn} propagate
     *     unchanged.
     */
    @Override
    public ToolCallResponse interceptToolCall(ToolCallRequest request, ToolCallHandler handler) {
        String toolName = request.getToolName();
        if (this.toolNames != null && !this.toolNames.contains(toolName)) {
            // Not a tool this interceptor is responsible for: pass straight through.
            return handler.call(request);
        }
        // Computed as long so maxRetries == Integer.MAX_VALUE does not overflow in messages.
        long totalAttempts = (long) this.maxRetries + 1L;
        Exception lastException = null;
        for (int attempt = 0; attempt <= this.maxRetries; ++attempt) {
            try {
                return handler.call(request);
            }
            catch (Exception e) {
                lastException = e;
                if (!this.retryOn.test(e)) {
                    log.debug("Exception {} not configured for retry, re-throwing", e.getClass().getSimpleName());
                    throw e;
                }
                if (attempt == this.maxRetries) {
                    // Final attempt failed; no point sleeping again.
                    break;
                }
                long delay = this.calculateDelay(attempt);
                log.warn("Tool '{}' failed (attempt {}/{}), retrying in {}ms: {}",
                        toolName, attempt + 1, totalAttempts, delay, e.getMessage());
                sleepBeforeRetry(delay);
            }
        }
        // Reached only via the break above, so lastException holds the final retryable failure.
        if (this.onFailure == OnFailureBehavior.RAISE) {
            throw new RuntimeException("Tool call failed after " + totalAttempts + " attempts", lastException);
        }
        String errorMessage = this.errorFormatter != null
                ? this.errorFormatter.apply(lastException)
                : "Tool call failed after " + totalAttempts + " attempts: " + lastException.getMessage();
        log.error("Tool '{}' failed after {} attempts: {}", toolName, totalAttempts, lastException.getMessage());
        return ToolCallResponse.of(request.getToolCallId(), request.getToolName(), errorMessage);
    }

    /**
     * Computes the wait before the retry that follows attempt {@code retryNumber} (0-based).
     *
     * <p>The base value is {@code initialDelayMs * backoffFactor^retryNumber}, capped at
     * {@code maxDelayMs}. When jitter is enabled, it is then scaled by a random factor in
     * [0.75, 1.25). Jitter is applied after the cap, so the result may exceed
     * {@code maxDelayMs} by up to 25%.
     *
     * @param retryNumber zero-based index of the attempt that just failed
     * @return a non-negative delay in milliseconds
     */
    private long calculateDelay(int retryNumber) {
        // A double-to-long cast saturates at Long.MAX_VALUE, so huge exponents are bounded
        // by the cap below instead of wrapping around.
        long delay = (long) (this.initialDelayMs * Math.pow(this.backoffFactor, retryNumber));
        delay = Math.min(delay, this.maxDelayMs);
        if (this.jitter) {
            delay = (long) (delay * ThreadLocalRandom.current().nextDouble(0.75, 1.25));
        }
        // Thread.sleep rejects negative arguments; never hand it one.
        return Math.max(0L, delay);
    }

    /**
     * Blocks the current thread for {@code delayMs} milliseconds.
     *
     * <p>If the thread is interrupted, restores its interrupt flag and aborts the retry loop
     * by throwing.
     *
     * @param delayMs non-negative delay in milliseconds
     * @throws RuntimeException if interrupted while sleeping
     */
    private static void sleepBeforeRetry(long delayMs) {
        try {
            Thread.sleep(delayMs);
        }
        catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Retry interrupted", ie);
        }
    }

    @Override
    public String getName() {
        return "ToolRetry";
    }

    /** Fluent builder for {@link ToolRetryInterceptor}. Not thread-safe. */
    public static class Builder {
        private int maxRetries = 2;
        private Set<String> toolNames;
        private Predicate<Exception> retryOn = e -> true;
        private OnFailureBehavior onFailure = OnFailureBehavior.RETURN_MESSAGE;
        private Function<Exception, String> errorFormatter;
        private double backoffFactor = 2.0;
        private long initialDelayMs = 1000L;
        private long maxDelayMs = 60000L;
        private boolean jitter = true;

        /**
         * Sets how many times a failed call is retried after the first attempt.
         *
         * @param maxRetries number of retries; must be {@code >= 0}
         * @return this builder
         * @throws IllegalArgumentException if {@code maxRetries} is negative
         */
        public Builder maxRetries(int maxRetries) {
            if (maxRetries < 0) {
                throw new IllegalArgumentException("maxRetries must be >= 0");
            }
            this.maxRetries = maxRetries;
            return this;
        }

        /**
         * Restricts retrying to the given tool names, replacing any previously added names.
         *
         * <p>The set is copied, so later {@link #toolName(String)} calls never modify the
         * caller's collection.
         *
         * @param toolNames tool names to retry, or {@code null} to retry every tool
         * @return this builder
         */
        public Builder toolNames(Set<String> toolNames) {
            this.toolNames = toolNames != null ? new HashSet<>(toolNames) : null;
            return this;
        }

        /**
         * Adds a single tool name to the set of tools eligible for retry.
         *
         * @param toolName tool name to retry
         * @return this builder
         */
        public Builder toolName(String toolName) {
            if (this.toolNames == null) {
                this.toolNames = new HashSet<>();
            }
            this.toolNames.add(toolName);
            return this;
        }

        /**
         * Retries only exceptions that are instances of one of the given types (or their
         * subclasses). Passing no types means nothing is retried.
         *
         * @param exceptionTypes exception types to retry
         * @return this builder
         */
        @SafeVarargs
        public final Builder retryOn(Class<? extends Exception> ... exceptionTypes) {
            Set<Class<? extends Exception>> types = new HashSet<>(Arrays.asList(exceptionTypes));
            this.retryOn = e -> {
                for (Class<? extends Exception> type : types) {
                    if (type.isInstance(e)) {
                        return true;
                    }
                }
                return false;
            };
            return this;
        }

        /**
         * Retries only exceptions accepted by {@code predicate}.
         *
         * @param predicate retry decision; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code predicate} is {@code null}
         */
        public Builder retryOn(Predicate<Exception> predicate) {
            // Fail fast: a null predicate would otherwise NPE inside the retry catch block and
            // hide the tool's real exception.
            this.retryOn = Objects.requireNonNull(predicate, "predicate");
            return this;
        }

        /**
         * Sets what happens once every attempt has failed.
         *
         * @param behavior failure behavior
         * @return this builder
         */
        public Builder onFailure(OnFailureBehavior behavior) {
            this.onFailure = behavior;
            return this;
        }

        /**
         * Sets a custom formatter for the error message returned on final failure.
         *
         * <p>This also forces {@link OnFailureBehavior#RETURN_MESSAGE}, overriding any earlier
         * {@link #onFailure(OnFailureBehavior)} call.
         *
         * @param formatter maps the last exception to the returned message
         * @return this builder
         */
        public Builder errorFormatter(Function<Exception, String> formatter) {
            this.errorFormatter = formatter;
            this.onFailure = OnFailureBehavior.RETURN_MESSAGE;
            return this;
        }

        /**
         * Sets the multiplier applied to the delay for each successive retry.
         *
         * @param backoffFactor non-negative, non-NaN multiplier
         * @return this builder
         * @throws IllegalArgumentException if {@code backoffFactor} is negative or NaN
         */
        public Builder backoffFactor(double backoffFactor) {
            // Written as !(x >= 0) so that NaN is rejected too.
            if (!(backoffFactor >= 0.0)) {
                throw new IllegalArgumentException("backoffFactor must be >= 0");
            }
            this.backoffFactor = backoffFactor;
            return this;
        }

        /**
         * Sets the delay before the first retry.
         *
         * @param initialDelayMs delay in milliseconds; must be {@code >= 0}
         * @return this builder
         * @throws IllegalArgumentException if {@code initialDelayMs} is negative
         */
        public Builder initialDelay(long initialDelayMs) {
            if (initialDelayMs < 0L) {
                throw new IllegalArgumentException("initialDelayMs must be >= 0");
            }
            this.initialDelayMs = initialDelayMs;
            return this;
        }

        /**
         * Sets the cap applied to the backoff delay (before jitter).
         *
         * @param maxDelayMs cap in milliseconds; must be {@code >= 0}
         * @return this builder
         * @throws IllegalArgumentException if {@code maxDelayMs} is negative
         */
        public Builder maxDelay(long maxDelayMs) {
            if (maxDelayMs < 0L) {
                throw new IllegalArgumentException("maxDelayMs must be >= 0");
            }
            this.maxDelayMs = maxDelayMs;
            return this;
        }

        /**
         * Enables or disables randomizing each delay by a factor in [0.75, 1.25).
         *
         * @param jitter whether to apply jitter
         * @return this builder
         */
        public Builder jitter(boolean jitter) {
            this.jitter = jitter;
            return this;
        }

        /**
         * Builds an interceptor from the current settings.
         *
         * @return a new {@link ToolRetryInterceptor}
         */
        public ToolRetryInterceptor build() {
            return new ToolRetryInterceptor(this);
        }
    }

    /** Behavior once every attempt of a tool call has failed. */
    public static enum OnFailureBehavior {
        /** Throw a {@link RuntimeException} whose cause is the last failure. */
        RAISE,
        /** Return a {@link ToolCallResponse} whose content is an error message. */
        RETURN_MESSAGE;

    }
}

