package com.agentsflex.llm.spark;

import com.agentsflex.core.document.Document;
import com.agentsflex.core.llm.BaseLlm;
import com.agentsflex.core.llm.ChatContext;
import com.agentsflex.core.llm.ChatOptions;
import com.agentsflex.core.llm.MessageResponse;
import com.agentsflex.core.llm.StreamResponseListener;
import com.agentsflex.core.llm.client.BaseLlmClientListener;
import com.agentsflex.core.llm.client.HttpClient;
import com.agentsflex.core.llm.client.impl.WebSocketClient;
import com.agentsflex.core.llm.embedding.EmbeddingOptions;
import com.agentsflex.core.llm.response.AbstractBaseMessageResponse;
import com.agentsflex.core.llm.response.AiMessageResponse;
import com.agentsflex.core.llm.response.FunctionMessageResponse;
import com.agentsflex.core.message.AiMessage;
import com.agentsflex.core.message.FunctionMessage;
import com.agentsflex.core.parser.AiMessageParser;
import com.agentsflex.core.parser.FunctionMessageParser;
import com.agentsflex.core.prompt.Prompt;
import com.agentsflex.core.store.VectorData;
import com.agentsflex.core.util.StringUtil;
import com.alibaba.fastjson.JSONPath;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/agentsflex/llm/spark/SparkLlm.class */
/**
 * Spark chat and embedding model built on {@link BaseLlm}.
 *
 * <p>{@link #chatStream} opens a new {@link WebSocketClient} per call and forwards its events
 * through a {@link BaseLlmClientListener}. {@link #chat} is a blocking wrapper that waits for that
 * stream to finish. {@link #embed} issues a single HTTP POST through the shared {@link HttpClient}.
 */
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
    private static final Logger logger = LoggerFactory.getLogger(SparkLlm.class);

    /** Message parser handed to every {@link BaseLlmClientListener} created by {@link #chatStream}. */
    public AiMessageParser aiMessageParser;

    /** Function-call parser handed to every {@link BaseLlmClientListener} created by {@link #chatStream}. */
    public FunctionMessageParser functionMessageParser;

    private final HttpClient httpClient;

    /**
     * Creates a model for the given configuration.
     *
     * <p>Uses the default parsers supplied by {@code SparkLlmUtil} and a fresh {@link HttpClient}
     * for embedding requests.
     *
     * @param sparkLlmConfig connection and model settings
     */
    public SparkLlm(SparkLlmConfig sparkLlmConfig) {
        super(sparkLlmConfig);
        this.aiMessageParser = SparkLlmUtil.getAiMessageParser();
        this.functionMessageParser = SparkLlmUtil.getFunctionMessageParser();
        this.httpClient = new HttpClient();
    }

    /**
     * Requests an embedding vector for {@code document}.
     *
     * <p>The JSON response must carry {@code $.header.code == 0}. Its {@code $.payload.feature.text}
     * field is Base64 text whose decoded bytes are read as consecutive little-endian 32-bit floats.
     *
     * @param document the document to embed
     * @param embeddingOptions not used by this implementation
     * @return the embedding, or {@code null} in any of these cases: the response body is empty; the
     *     header code is missing or non-zero (the raw response is logged as an error); or the
     *     feature text is empty
     */
    public VectorData embed(Document document, EmbeddingOptions embeddingOptions) {
        SparkLlmConfig sparkConfig = (SparkLlmConfig) this.config;
        String response = this.httpClient.post(SparkLlmUtil.createEmbedURL(sparkConfig), (Map) null, SparkLlmUtil.embedPayload(sparkConfig, document));
        if (StringUtil.noText(response)) {
            return null;
        }
        Integer code = (Integer) JSONPath.read(response, "$.header.code", Integer.class);
        if (code == null || code.intValue() != 0) {
            logger.error(response);
            return null;
        }
        String featureText = (String) JSONPath.read(response, "$.payload.feature.text", String.class);
        if (StringUtil.noText(featureText)) {
            return null;
        }
        VectorData vectorData = new VectorData();
        vectorData.setVector(decodeLittleEndianFloats(featureText));
        return vectorData;
    }

    /**
     * Decodes Base64 text into consecutive little-endian 32-bit floats, widened to doubles.
     *
     * <p>Trailing bytes that do not form a complete float are ignored.
     *
     * @param base64 Base64-encoded float data
     * @return the decoded values
     * @throws IllegalArgumentException if {@code base64} is not valid Base64
     */
    private static double[] decodeLittleEndianFloats(String base64) {
        byte[] bytes = Base64.getDecoder().decode(base64);
        // Wrap once and use absolute reads, rather than allocating a buffer per element.
        ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
        double[] values = new double[bytes.length / Float.BYTES];
        for (int i = 0; i < values.length; i++) {
            values[i] = buffer.getFloat(i * Float.BYTES);
        }
        return values;
    }

    /**
     * Sends {@code prompt} and blocks until the underlying stream stops or fails.
     *
     * <p>If the final response has no message, or a failure was reported, the response is flagged
     * via {@code setError(true)`}. When a failure occurred, its message is also attached.
     *
     * @param prompt the prompt to send
     * @param chatOptions per-request options forwarded to {@link #chatStream}
     * @return the last response observed on the stream, or {@code null} if none arrived
     * @throws RuntimeException wrapping {@link InterruptedException} if the calling thread is
     *     interrupted while waiting; the interrupt flag is restored first
     */
    public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions chatOptions) {
        Throwable[] failureHolder = new Throwable[1];
        AbstractBaseMessageResponse<?>[] responseHolder = {null};
        waitResponse(prompt, chatOptions, responseHolder, new CountDownLatch(1), failureHolder);

        AbstractBaseMessageResponse<?> response = responseHolder[0];
        if (response == null) {
            return null;
        }
        Throwable failure = failureHolder[0];
        if (response.getMessage() == null || failure != null) {
            response.setError(true);
            if (failure != null) {
                response.setErrorMessage(failure.getMessage());
            }
        } else {
            response.setError(false);
        }
        // NOTE(review): the listener stores either a FunctionMessageResponse or an AiMessageResponse.
        // This cast assumes that matches the caller's R; confirm against Prompt implementations.
        @SuppressWarnings("unchecked")
        R typed = (R) response;
        return typed;
    }

    /**
     * Starts a streaming chat and waits until the stream stops or fails.
     *
     * <p>The latch's await/countDown pairing makes the holder writes visible to the caller.
     *
     * @param responseHolder single-slot array that receives the latest response
     * @param latch released when the stream stops or fails
     * @param failureHolder single-slot array that receives the reported failure, if any
     */
    private <R extends MessageResponse<?>> void waitResponse(Prompt<R> prompt, ChatOptions chatOptions,
            final AbstractBaseMessageResponse<?>[] responseHolder, final CountDownLatch latch,
            final Throwable[] failureHolder) {
        chatStream(prompt, new StreamResponseListener<R>() {
            @Override
            public void onMessage(ChatContext chatContext, R messageResponse) {
                if (messageResponse.getMessage() instanceof FunctionMessage) {
                    responseHolder[0] = (FunctionMessageResponse) messageResponse;
                    return;
                }
                // Snapshot the accumulated content into a fresh AiMessage for the blocking result.
                AiMessage aiMessage = new AiMessage();
                aiMessage.setContent(messageResponse.getMessage().getFullContent());
                responseHolder[0] = new AiMessageResponse(aiMessage);
            }

            @Override
            public void onStop(ChatContext chatContext) {
                StreamResponseListener.super.onStop(chatContext);
                latch.countDown();
            }

            @Override
            public void onFailure(ChatContext chatContext, Throwable th) {
                SparkLlm.logger.error(th.toString(), th);
                failureHolder[0] = th;
                // Release the waiting caller even if no onStop follows the failure.
                // Extra countDown() calls on an exhausted latch are no-ops.
                latch.countDown();
            }
        }, chatOptions);
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Streams a chat completion for {@code prompt} over a new WebSocket connection.
     *
     * <p>Events are delivered to {@code streamResponseListener} through a
     * {@link BaseLlmClientListener} configured with this instance's parsers.
     *
     * @param prompt the prompt to send
     * @param streamResponseListener receives streamed responses and lifecycle callbacks
     * @param chatOptions per-request options used when building the payload
     */
    public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> streamResponseListener, ChatOptions chatOptions) {
        SparkLlmConfig sparkConfig = (SparkLlmConfig) this.config;
        WebSocketClient webSocketClient = new WebSocketClient();
        webSocketClient.start(
                SparkLlmUtil.createURL(sparkConfig),
                (Map) null,
                SparkLlmUtil.promptToPayload(prompt, sparkConfig, chatOptions),
                new BaseLlmClientListener(this, webSocketClient, streamResponseListener, prompt, this.aiMessageParser, this.functionMessageParser),
                this.config);
    }
}
