/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.chat;

import dev.langchain4j.agent.tool.JsonSchemaProperty;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.output.Response;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.Fail;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.Test;

public abstract class StreamingChatModelListenerIT {
    protected abstract StreamingChatLanguageModel createModel(ChatModelListener var1);

    protected abstract String modelName();

    protected Double temperature() {
        return 0.7;
    }

    protected Double topP() {
        return 1.0;
    }

    protected Integer maxTokens() {
        return 7;
    }

    protected abstract StreamingChatLanguageModel createFailingModel(ChatModelListener var1);

    protected abstract Class<? extends Exception> expectedExceptionClass();

    @Test
    void should_listen_request_and_response() {
        final AtomicReference requestReference = new AtomicReference();
        final AtomicReference responseReference = new AtomicReference();
        ChatModelListener listener = new ChatModelListener(){

            public void onRequest(ChatModelRequestContext requestContext) {
                requestReference.set(requestContext.request());
                requestContext.attributes().put("id", "12345");
            }

            public void onResponse(ChatModelResponseContext responseContext) {
                responseReference.set(responseContext.response());
                Assertions.assertThat((Object)responseContext.request()).isSameAs(requestReference.get());
                Assertions.assertThat(responseContext.attributes().get("id")).isEqualTo((Object)"12345");
            }

            public void onError(ChatModelErrorContext errorContext) {
                Fail.fail((String)"onError() must not be called");
            }
        };
        StreamingChatLanguageModel model = this.createModel(listener);
        UserMessage userMessage = UserMessage.from((String)"hello");
        ToolSpecification toolSpecification = null;
        if (this.supportsTools()) {
            toolSpecification = ToolSpecification.builder().name("add").addParameter("a", new JsonSchemaProperty[]{JsonSchemaProperty.INTEGER}).addParameter("b", new JsonSchemaProperty[]{JsonSchemaProperty.INTEGER}).build();
        }
        TestStreamingResponseHandler handler = new TestStreamingResponseHandler();
        if (this.supportsTools()) {
            model.generate(Collections.singletonList(userMessage), Collections.singletonList(toolSpecification), handler);
        } else {
            model.generate(Collections.singletonList(userMessage), handler);
        }
        AiMessage aiMessage = (AiMessage)handler.get().content();
        ChatModelRequest request = (ChatModelRequest)requestReference.get();
        Assertions.assertThat((String)request.model()).isEqualTo(this.modelName());
        Assertions.assertThat((Double)request.temperature()).isCloseTo(this.temperature(), Percentage.withPercentage((double)1.0));
        Assertions.assertThat((Double)request.topP()).isEqualTo(this.topP());
        Assertions.assertThat((Integer)request.maxTokens()).isEqualTo((Object)this.maxTokens());
        Assertions.assertThat((List)request.messages()).containsExactly((Object[])new ChatMessage[]{userMessage});
        if (this.supportsTools()) {
            Assertions.assertThat((List)request.toolSpecifications()).containsExactly((Object[])new ToolSpecification[]{toolSpecification});
        }
        ChatModelResponse response = (ChatModelResponse)responseReference.get();
        if (this.assertResponseId()) {
            Assertions.assertThat((String)response.id()).isNotBlank();
        }
        if (this.assertResponseModel()) {
            Assertions.assertThat((String)response.model()).isNotBlank();
        }
        if (this.assertTokenUsage()) {
            Assertions.assertThat((Integer)response.tokenUsage().inputTokenCount()).isGreaterThan(0);
            Assertions.assertThat((Integer)response.tokenUsage().outputTokenCount()).isGreaterThan(0);
            Assertions.assertThat((Integer)response.tokenUsage().totalTokenCount()).isGreaterThan(0);
        }
        if (this.assertFinishReason()) {
            Assertions.assertThat((Comparable)response.finishReason()).isNotNull();
        }
        Assertions.assertThat((Object)response.aiMessage()).isEqualTo((Object)aiMessage);
    }

    protected boolean supportsTools() {
        return true;
    }

    protected boolean assertResponseId() {
        return true;
    }

    protected boolean assertResponseModel() {
        return true;
    }

    protected boolean assertTokenUsage() {
        return true;
    }

    protected boolean assertFinishReason() {
        return true;
    }

    @Test
    protected void should_listen_error() throws Exception {
        final AtomicReference requestReference = new AtomicReference();
        final AtomicReference errorReference = new AtomicReference();
        ChatModelListener listener = new ChatModelListener(){

            public void onRequest(ChatModelRequestContext requestContext) {
                requestReference.set(requestContext.request());
                requestContext.attributes().put("id", "12345");
            }

            public void onResponse(ChatModelResponseContext responseContext) {
                Fail.fail((String)"onResponse() must not be called");
            }

            public void onError(ChatModelErrorContext errorContext) {
                errorReference.set(errorContext.error());
                Assertions.assertThat((Object)errorContext.request()).isSameAs(requestReference.get());
                Assertions.assertThat((Object)errorContext.partialResponse()).isNull();
                Assertions.assertThat(errorContext.attributes().get("id")).isEqualTo((Object)"12345");
            }
        };
        StreamingChatLanguageModel model = this.createFailingModel(listener);
        String userMessage = "this message will fail";
        final CompletableFuture future = new CompletableFuture();
        StreamingResponseHandler<AiMessage> handler = new StreamingResponseHandler<AiMessage>(){

            public void onNext(String token) {
                Fail.fail((String)"onNext() must not be called");
            }

            public void onError(Throwable error) {
                future.complete(error);
            }

            public void onComplete(Response<AiMessage> response) {
                Fail.fail((String)"onComplete() must not be called");
            }
        };
        model.generate(userMessage, (StreamingResponseHandler)handler);
        Throwable throwable = (Throwable)future.get(5L, TimeUnit.SECONDS);
        Assertions.assertThat((Throwable)throwable).isExactlyInstanceOf(this.expectedExceptionClass());
        Assertions.assertThat((Throwable)((Throwable)errorReference.get())).isSameAs((Object)throwable);
    }
}

