/*
 * Decompiled with CFR 0.152.
 */
package com.agentsflex.llm.spark;

import com.agentsflex.core.document.Document;
import com.agentsflex.core.llm.BaseLlm;
import com.agentsflex.core.llm.ChatContext;
import com.agentsflex.core.llm.ChatOptions;
import com.agentsflex.core.llm.Llm;
import com.agentsflex.core.llm.LlmConfig;
import com.agentsflex.core.llm.MessageResponse;
import com.agentsflex.core.llm.StreamResponseListener;
import com.agentsflex.core.llm.client.BaseLlmClientListener;
import com.agentsflex.core.llm.client.HttpClient;
import com.agentsflex.core.llm.client.LlmClient;
import com.agentsflex.core.llm.client.LlmClientListener;
import com.agentsflex.core.llm.client.impl.WebSocketClient;
import com.agentsflex.core.llm.embedding.EmbeddingOptions;
import com.agentsflex.core.llm.response.AbstractBaseMessageResponse;
import com.agentsflex.core.llm.response.AiMessageResponse;
import com.agentsflex.core.llm.response.FunctionMessageResponse;
import com.agentsflex.core.message.AiMessage;
import com.agentsflex.core.message.FunctionMessage;
import com.agentsflex.core.parser.AiMessageParser;
import com.agentsflex.core.parser.FunctionMessageParser;
import com.agentsflex.core.prompt.Prompt;
import com.agentsflex.core.store.VectorData;
import com.agentsflex.core.util.StringUtil;
import com.agentsflex.llm.spark.SparkLlmConfig;
import com.agentsflex.llm.spark.SparkLlmUtil;
import com.alibaba.fastjson.JSONPath;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Base64;
import java.util.concurrent.CountDownLatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SparkLlm
extends BaseLlm<SparkLlmConfig> {
    private static final Logger logger = LoggerFactory.getLogger(SparkLlm.class);
    public AiMessageParser aiMessageParser = SparkLlmUtil.getAiMessageParser();
    public FunctionMessageParser functionMessageParser = SparkLlmUtil.getFunctionMessageParser();
    private final HttpClient httpClient = new HttpClient();

    public SparkLlm(SparkLlmConfig config) {
        super((LlmConfig)config);
    }

    public VectorData embed(Document document, EmbeddingOptions options) {
        String payload = SparkLlmUtil.embedPayload((SparkLlmConfig)this.config, document);
        String resp = this.httpClient.post(SparkLlmUtil.createEmbedURL((SparkLlmConfig)this.config), null, payload);
        if (StringUtil.noText((String)resp)) {
            return null;
        }
        Integer code = (Integer)JSONPath.read((String)resp, (String)"$.header.code", Integer.class);
        if (code == null || code != 0) {
            logger.error(resp);
            return null;
        }
        String text = (String)JSONPath.read((String)resp, (String)"$.payload.feature.text", String.class);
        if (StringUtil.noText((String)text)) {
            return null;
        }
        byte[] buffer = Base64.getDecoder().decode(text);
        double[] vector = new double[buffer.length / 4];
        for (int i = 0; i < vector.length; ++i) {
            int n = i * 4;
            vector[i] = ByteBuffer.wrap(buffer, n, 4).order(ByteOrder.LITTLE_ENDIAN).getFloat();
        }
        VectorData vectorData = new VectorData();
        vectorData.setVector(vector);
        return vectorData;
    }

    public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
        CountDownLatch latch = new CountDownLatch(1);
        Throwable[] failureThrowable = new Throwable[1];
        AbstractBaseMessageResponse[] messageResponse = new AbstractBaseMessageResponse[]{null};
        this.waitResponse(prompt, options, messageResponse, latch, failureThrowable);
        AbstractBaseMessageResponse response = messageResponse[0];
        if (response == null) {
            return null;
        }
        Throwable fialureThrowable = failureThrowable[0];
        if (null == response.getMessage() || fialureThrowable != null) {
            response.setError(true);
            if (fialureThrowable != null) {
                response.setErrorMessage(fialureThrowable.getMessage());
            }
        } else {
            response.setError(false);
        }
        return (R)response;
    }

    private <R extends MessageResponse<?>> void waitResponse(Prompt<R> prompt, ChatOptions options, final AbstractBaseMessageResponse<?>[] messageResponse, final CountDownLatch latch, final Throwable[] failureThrowable) {
        this.chatStream(prompt, new StreamResponseListener<R>(){

            public void onMessage(ChatContext context, R response) {
                if (response.getMessage() instanceof FunctionMessage) {
                    messageResponse[0] = (FunctionMessageResponse)response;
                } else {
                    AiMessage aiMessage = new AiMessage();
                    aiMessage.setContent(response.getMessage().getFullContent());
                    messageResponse[0] = new AiMessageResponse(aiMessage);
                }
            }

            public void onStop(ChatContext context) {
                super.onStop(context);
                latch.countDown();
            }

            public void onFailure(ChatContext context, Throwable throwable) {
                logger.error(throwable.toString(), throwable);
                failureThrowable[0] = throwable;
            }
        }, options);
        try {
            latch.await();
        }
        catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
    }

    public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
        WebSocketClient llmClient = new WebSocketClient();
        String url = SparkLlmUtil.createURL((SparkLlmConfig)this.config);
        String payload = SparkLlmUtil.promptToPayload(prompt, (SparkLlmConfig)this.config, options);
        BaseLlmClientListener clientListener = new BaseLlmClientListener((Llm)this, (LlmClient)llmClient, listener, prompt, this.aiMessageParser, this.functionMessageParser);
        llmClient.start(url, null, payload, (LlmClientListener)clientListener, this.config);
    }
}

