package com.alibaba.cloud.ai.service.impl;

import com.alibaba.cloud.ai.common.ModelType;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
import com.alibaba.cloud.ai.exception.NotFoundException;
import com.alibaba.cloud.ai.model.ChatModelConfig;
import com.alibaba.cloud.ai.param.ModelRunActionParam;
import com.alibaba.cloud.ai.service.ChatModelDelegate;
import com.alibaba.cloud.ai.utils.SpringApplicationUtil;
import com.alibaba.cloud.ai.vo.ActionResult;
import com.alibaba.cloud.ai.vo.ChatModelRunResult;
import com.alibaba.cloud.ai.vo.TelemetryResult;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.tracing.Span;
import io.micrometer.tracing.Tracer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.image.ImageMessage;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

/**
 * {@link ChatModelDelegate} that resolves {@link ChatModel} and {@link ImageModel} Spring beans by
 * bean name, describes them as {@link ChatModelConfig}s, runs chat requests against them, and
 * performs DashScope image generation.
 */
@Service
public class ChatModelDelegateImpl implements ChatModelDelegate {
    private static final Logger log = LoggerFactory.getLogger(ChatModelDelegateImpl.class);
    private final Tracer tracer;
    private final ObjectMapper objectMapper = new ObjectMapper();

    public ChatModelDelegateImpl(Tracer tracer) {
        this.tracer = tracer;
    }

    /**
     * Lists the model identifiers declared by the DashScope API enums for the given model type.
     *
     * @param modelType {@link ModelType#CHAT} or {@link ModelType#IMAGE}
     * @return the model identifiers; empty for any other type
     */
    @Override // com.alibaba.cloud.ai.service.ChatModelDelegate
    public List<String> listModelNames(ModelType modelType) {
        List<String> names = new ArrayList<>();
        if (modelType == ModelType.CHAT) {
            for (DashScopeApi.ChatModel chatModel : DashScopeApi.ChatModel.values()) {
                names.add(chatModel.getModel());
            }
        } else if (modelType == ModelType.IMAGE) {
            for (DashScopeImageApi.ImageModel imageModel : DashScopeImageApi.ImageModel.values()) {
                names.add(imageModel.getValue());
            }
        }
        return names;
    }

    /** Returns the {@link ChatModel} bean registered under {@code beanName}, or {@code null}. */
    private ChatModel getChatModel(String beanName) {
        for (Map.Entry<String, ChatModel> entry : SpringApplicationUtil.getBeans(ChatModel.class).entrySet()) {
            ChatModel chatModel = entry.getValue();
            log.info("bean name:{}, bean Class:{}", entry.getKey(), chatModel.getClass());
            if (entry.getKey().equals(beanName)) {
                return chatModel;
            }
        }
        return null;
    }

    /** Returns the {@link ImageModel} bean registered under {@code beanName}, or {@code null}. */
    private ImageModel getImageModel(String beanName) {
        for (Map.Entry<String, ImageModel> entry : SpringApplicationUtil.getBeans(ImageModel.class).entrySet()) {
            ImageModel imageModel = entry.getValue();
            log.info("bean name:{}, bean Class:{}", entry.getKey(), imageModel.getClass());
            if (entry.getKey().equals(beanName)) {
                return imageModel;
            }
        }
        return null;
    }

    /** Builds the CHAT-typed config for a chat model bean. */
    private static ChatModelConfig toChatModelConfig(String beanName, ChatModel chatModel) {
        ChatModelConfig config = ChatModelConfig.builder()
                .name(beanName)
                .model(chatModel.getDefaultOptions().getModel())
                .modelType(ModelType.CHAT)
                .build();
        // Exact-class check: only a plain DashScopeChatModel exposes its DashScope-specific options.
        if (chatModel.getClass() == DashScopeChatModel.class) {
            config.setChatOptions(((DashScopeChatModel) chatModel).getDashScopeChatOptions());
        }
        return config;
    }

    /** Builds the IMAGE-typed config for an image model bean. */
    private static ChatModelConfig toImageModelConfig(String beanName, ImageModel imageModel) {
        ChatModelConfig config = ChatModelConfig.builder()
                .name(beanName)
                .modelType(ModelType.IMAGE)
                .build();
        // Model name and options are only known for a plain DashScopeImageModel.
        if (imageModel.getClass() == DashScopeImageModel.class) {
            DashScopeImageModel dashScopeImageModel = (DashScopeImageModel) imageModel;
            config.setModel(dashScopeImageModel.getOptions().getModel());
            config.setImageOptions(dashScopeImageModel.getOptions());
        }
        return config;
    }

    /** Serializes {@code value} for logging, wrapping serialization failures. */
    private String toJson(Object value) {
        try {
            return this.objectMapper.writeValueAsString(value);
        } catch (Exception e) {
            throw new RuntimeException("Failed to serialize JSON", e);
        }
    }

    /** Returns the trace id of the current span, or {@code null} when no span is active. */
    private String currentTraceId() {
        Span span = this.tracer.currentSpan();
        return span == null ? null : span.context().traceId();
    }

    /**
     * Describes the chat or image model bean registered under {@code str}.
     *
     * <p>Chat model beans are checked first.
     *
     * @param str bean name
     * @return the model's config
     * @throws NotFoundException if no chat or image model bean has that name
     */
    @Override // com.alibaba.cloud.ai.service.ChatModelDelegate
    public ChatModelConfig getByModelName(String str) {
        ChatModel chatModel = getChatModel(str);
        if (chatModel != null) {
            return toChatModelConfig(str, chatModel);
        }
        ImageModel imageModel = getImageModel(str);
        if (imageModel == null) {
            log.error("can not find by bean name:{}", str);
            throw new NotFoundException();
        }
        return toImageModelConfig(str, imageModel);
    }

    /**
     * Runs a chat request against the chat model bean named by the request key.
     *
     * <p>If the request carries chat options and the bean is a plain {@link DashScopeChatModel},
     * those options are set on the bean before the call.
     *
     * @param modelRunActionParam request carrying the bean name, user input, optional system prompt
     *     and optional chat options
     * @return the model's text output plus the current trace id (null if no span is active)
     * @throws NotFoundException if no chat model bean has the given name
     */
    @Override // com.alibaba.cloud.ai.service.ChatModelDelegate
    public ChatModelRunResult run(ModelRunActionParam modelRunActionParam) {
        String key = modelRunActionParam.getKey();
        String input = modelRunActionParam.getInput();
        DashScopeChatOptions chatOptions = modelRunActionParam.getChatOptions();
        String prompt = modelRunActionParam.getPrompt();
        ChatModel chatModel = getChatModel(key);
        if (chatModel == null) {
            log.error("can not find by bean name:{}", key);
            throw new NotFoundException();
        }
        if (chatModel.getClass() == DashScopeChatModel.class && chatOptions != null) {
            log.info("set chat options, {}", toJson(chatOptions));
            // NOTE(review): this mutates the shared singleton bean, so the options also apply to
            // later callers; confirm this persistence is intended.
            ((DashScopeChatModel) chatModel).setDashScopeChatOptions(chatOptions);
        }
        List<Message> messages = new ArrayList<>();
        if (StringUtils.hasText(prompt)) {
            messages.add(new SystemMessage(prompt));
        }
        messages.add(new UserMessage(input));
        String output = chatModel.call(new Prompt(messages)).getResult().getOutput().getText();
        return ChatModelRunResult.builder()
                .input(modelRunActionParam)
                .result(ActionResult.builder().Response(output).build())
                .telemetry(TelemetryResult.builder().traceId(currentTraceId()).build())
                .build();
    }

    /**
     * Generates an image with a {@link DashScopeImageModel} built for this call.
     *
     * <p>When the request carries no image options,
     * {@code DashScopeImageApi.ImageModel.WANX_V1} is used as the model.
     *
     * @param modelRunActionParam request carrying the key, input text, optional extra prompt and
     *     optional image options
     * @return the URL of the response's result image
     * @throws IllegalArgumentException if the key is {@code null}
     */
    @Override // com.alibaba.cloud.ai.service.ChatModelDelegate
    public String runImageGenTask(ModelRunActionParam modelRunActionParam) {
        String key = modelRunActionParam.getKey();
        String input = modelRunActionParam.getInput();
        DashScopeImageOptions imageOptions = modelRunActionParam.getImageOptions();
        String prompt = modelRunActionParam.getPrompt();
        if (key == null) {
            // Without a key there is no API to call; fail fast instead of hitting an NPE later.
            throw new IllegalArgumentException("key must not be null");
        }
        if (imageOptions != null) {
            log.info("set image options, {}", toJson(imageOptions));
        } else {
            // Default model so both the image model and the prompt options are never null.
            imageOptions = DashScopeImageOptions.builder()
                    .withModel(DashScopeImageApi.ImageModel.WANX_V1.getValue())
                    .build();
        }
        // NOTE(review): the key is passed to the DashScopeImageApi constructor here, whereas run()
        // treats the key as a bean name; confirm which meaning callers intend.
        DashScopeImageModel imageModel = new DashScopeImageModel(new DashScopeImageApi(key), imageOptions);
        ImageOptions options = ImageOptionsBuilder.builder()
                .model(imageOptions.getModel())
                .N(imageOptions.getN())
                .width(imageOptions.getWidth())
                .height(imageOptions.getHeight())
                .style(imageOptions.getStyle())
                .build();
        List<ImageMessage> messages = new ArrayList<>();
        if (StringUtils.hasText(prompt)) {
            messages.add(new ImageMessage(prompt));
        }
        messages.add(new ImageMessage(input));
        return imageModel.call(new ImagePrompt(messages, options)).getResult().getOutput().getUrl();
    }

    /**
     * Describes every chat model bean, followed by every image model bean.
     *
     * @return one config per bean
     */
    @Override // com.alibaba.cloud.ai.service.ChatModelDelegate
    public List<ChatModelConfig> list() {
        List<ChatModelConfig> configs = new ArrayList<>();
        for (Map.Entry<String, ChatModel> entry : SpringApplicationUtil.getBeans(ChatModel.class).entrySet()) {
            ChatModel chatModel = entry.getValue();
            log.info("bean name:{}, bean Class:{}", entry.getKey(), chatModel.getClass());
            configs.add(toChatModelConfig(entry.getKey(), chatModel));
        }
        for (Map.Entry<String, ImageModel> entry : SpringApplicationUtil.getBeans(ImageModel.class).entrySet()) {
            ImageModel imageModel = entry.getValue();
            log.info("bean name:{}, bean Class:{}", entry.getKey(), imageModel.getClass());
            configs.add(toImageModelConfig(entry.getKey(), imageModel));
        }
        return configs;
    }

    /**
     * Runs {@link #runImageGenTask} and wraps the resulting URL in a {@link ChatModelRunResult}.
     *
     * @param modelRunActionParam image generation request
     * @return the image URL as the response plus the current trace id (null if no span is active)
     */
    @Override // com.alibaba.cloud.ai.service.ChatModelDelegate
    public ChatModelRunResult runImageGenTaskAndGetUrl(ModelRunActionParam modelRunActionParam) {
        String url = runImageGenTask(modelRunActionParam);
        return ChatModelRunResult.builder()
                .input(modelRunActionParam)
                .result(ActionResult.builder().Response(url).build())
                .telemetry(TelemetryResult.builder().traceId(currentTraceId()).build())
                .build();
    }
}
