/*
 * Decompiled with CFR 0.152.
 */
package com.volcengine.ark.runtime.interceptor;

import com.volcengine.ApiClient;
import com.volcengine.ApiException;
import com.volcengine.ark.ArkApi;
import com.volcengine.ark.model.GetApiKeyRequest;
import com.volcengine.ark.model.GetApiKeyResponse;
import com.volcengine.ark.runtime.Const;
import com.volcengine.ark.runtime.exception.ArkException;
import com.volcengine.sign.Credentials;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.BiFunction;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.commons.lang.StringUtils;

public class ArkResourceStsAuthenticationInterceptor
implements Interceptor {
    private final String ak;
    private final String sk;
    private Map<String, ArkResourceStsTokenInfo> resourceStsTokens;
    private final Integer advisoryRefreshTimeout = Const.DEFAULT_ADVISORY_REFRESH_TIMEOUT;
    private final Integer mandatoryRefreshTimeout = Const.DEFAULT_MANDATORY_REFRESH_TIMEOUT;
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
    private final ArkApi volcClient;

    public ArkResourceStsAuthenticationInterceptor(String ak, String sk, String region) {
        ArkApi arkApi;
        Objects.requireNonNull(ak, "Ak token required");
        Objects.requireNonNull(sk, "Sk token required");
        this.ak = ak;
        this.sk = sk;
        this.resourceStsTokens = new ConcurrentHashMap<String, ArkResourceStsTokenInfo>();
        ApiClient apiClient = new ApiClient().setCredentials(Credentials.getCredentials((String)ak, (String)sk)).setRegion(region);
        this.volcClient = arkApi = new ArkApi(apiClient);
    }

    public Response intercept(Interceptor.Chain chain) throws IOException {
        Request request = chain.request();
        String requestResourceId = this.getRequestResourceId(request);
        String requestResourceType = this.getRequestResourceType(request, requestResourceId);
        String projectName = this.getProjectName(request);
        if (requestResourceType.equalsIgnoreCase("presetendpoint") && StringUtils.isBlank((String)projectName)) {
            throw new ArkException("project name is required for preset endpoint");
        }
        if (request.url().url().getPath().contains("contents/generations") || request.url().url().getPath().contains("images/generations")) {
            throw new ArkException("content generation currently does not support ak&sk authentication, use api_key instead.");
        }
        Request newRequest = chain.request().newBuilder().header("Authorization", "Bearer " + this.getResourceStsToken(requestResourceType, requestResourceId, projectName)).build();
        return chain.proceed(newRequest);
    }

    private String getRequestResourceType(Request request, String requestResourceId) {
        if (StringUtils.isNotBlank((String)request.header("X-Request-Bot"))) {
            return "bot";
        }
        if (StringUtils.isNotBlank((String)requestResourceId) && requestResourceId.startsWith("ep-m-")) {
            return "presetendpoint";
        }
        if (StringUtils.isNotBlank((String)requestResourceId) && requestResourceId.startsWith("ep-")) {
            return "endpoint";
        }
        return "presetendpoint";
    }

    private String getRequestResourceId(Request request) {
        if (StringUtils.isNotBlank((String)request.header("X-Request-Bot"))) {
            return request.header("X-Request-Bot");
        }
        return request.header("X-Request-Model");
    }

    private String getProjectName(Request request) {
        if (StringUtils.isNotBlank((String)request.header("X-Project-Name"))) {
            return request.header("X-Project-Name");
        }
        return "";
    }

    private String getResourceStsToken(String resourceType, String resourceId, String projectName) {
        this.refresh(resourceType, resourceId, projectName);
        ArkResourceStsTokenInfo tokenInfo = this.resourceStsTokens.get(this.getResourceKey(resourceType, resourceId));
        if (tokenInfo == null) {
            return "";
        }
        return tokenInfo.getToken();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void refresh(String resourceType, String resourceId, String projectName) {
        if (!this.need_refresh(resourceType, resourceId, this.advisoryRefreshTimeout)) {
            return;
        }
        if (this.lock.writeLock().tryLock()) {
            if (!this.need_refresh(resourceType, resourceId, this.advisoryRefreshTimeout)) {
                return;
            }
            try {
                boolean isMandatoryRefresh = this.need_refresh(resourceType, resourceId, this.mandatoryRefreshTimeout);
                this.protectedRefresh(resourceType, resourceId, isMandatoryRefresh, projectName);
            }
            finally {
                this.lock.writeLock().unlock();
            }
        }
        if (this.need_refresh(resourceType, resourceId, this.mandatoryRefreshTimeout)) {
            try {
                this.lock.writeLock().lock();
                if (!this.need_refresh(resourceType, resourceId, this.mandatoryRefreshTimeout)) {
                    return;
                }
                this.protectedRefresh(resourceType, resourceId, true, projectName);
            }
            finally {
                this.lock.writeLock().unlock();
            }
        }
    }

    private boolean need_refresh(String resourceType, String resourceId, Integer refresh_in) {
        ArkResourceStsTokenInfo tokenInfo = this.resourceStsTokens.get(this.getResourceKey(resourceType, resourceId));
        if (tokenInfo == null) {
            return true;
        }
        return (long)tokenInfo.getExpiredTime().intValue() - System.currentTimeMillis() / 1000L < (long)refresh_in.intValue();
    }

    private void protectedRefresh(final String resourceType, final String resourceId, final boolean isMandatory, final String projectName) {
        this.resourceStsTokens.compute(this.getResourceKey(resourceType, resourceId), new BiFunction<String, ArkResourceStsTokenInfo, ArkResourceStsTokenInfo>(){

            @Override
            public ArkResourceStsTokenInfo apply(String s, ArkResourceStsTokenInfo stringIntegerPair) {
                try {
                    ArkResourceStsTokenInfo tokenInfo = ArkResourceStsAuthenticationInterceptor.this.getToken(resourceType, resourceId, Const.DEFAULT_STS_TIMEOUT, projectName);
                    return tokenInfo;
                }
                catch (ApiException e) {
                    if (isMandatory) {
                        throw new RuntimeException(e);
                    }
                    return null;
                }
            }
        });
    }

    private ArkResourceStsTokenInfo getEndpointToken(String endpointId, Integer ttl) throws ApiException {
        return this.getToken("endpoint", endpointId, ttl, "");
    }

    private ArkResourceStsTokenInfo getToken(String resourceType, String resourceId, Integer ttl, String projectName) throws ApiException {
        if (ttl < this.advisoryRefreshTimeout * 2) {
            throw new ArkException("ttl should not be under " + this.advisoryRefreshTimeout * 2 + " seconds.");
        }
        GetApiKeyRequest r = new GetApiKeyRequest();
        r.durationSeconds(ttl);
        r.resourceType(resourceType);
        if (StringUtils.isNotBlank((String)projectName)) {
            r.projectName(projectName);
        }
        ArrayList<String> list = new ArrayList<String>();
        list.add(resourceId);
        r.resourceIds(list);
        GetApiKeyResponse response = this.volcClient.getApiKey(r);
        return new ArkResourceStsTokenInfo(response.getApiKey(), response.getExpiredTime());
    }

    private String getResourceKey(String resourceType, String resourceId) {
        return resourceType + "/" + resourceId;
    }

    public static class ArkResourceStsTokenInfo {
        private String token;
        private Integer expiredTime;

        public ArkResourceStsTokenInfo(String token, Integer expiredTime) {
            this.token = token;
            this.expiredTime = expiredTime;
        }

        public String getToken() {
            return this.token;
        }

        public void setToken(String token) {
            this.token = token;
        }

        public Integer getExpiredTime() {
            return this.expiredTime;
        }

        public void setExpiredTime(Integer expiredTime) {
            this.expiredTime = expiredTime;
        }
    }
}

