/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.io.StringWriter;
import java.io.Writer;
import java.net.URL;
import java.net.URLConnection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDevice;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link DockerCommandPlugin} for nvidia-docker v1.
 *
 * <p>The first time a container that requests GPUs is seen, the plugin queries the
 * nvidia-docker v1 REST endpoint configured by
 * {@code yarn.nodemanager.resource-plugins.gpu.docker-plugin.nvidia-docker-v1.endpoint}
 * (default {@code http://localhost:3476/v1.0/docker/cli}). It parses the
 * {@code --device}, {@code --volume-driver} and read-only {@code --volume} options
 * returned there and caches them. The cached options are then used to add devices and
 * read-only mounts to docker run commands, and to build the "create" command for the
 * named volume reported by the plugin.
 *
 * <p>Lazy initialization is guarded by this object's monitor; once published, the
 * cached options are never mutated.
 */
public class NvidiaDockerV1CommandPlugin implements DockerCommandPlugin {
    static final Logger LOG = LoggerFactory.getLogger(NvidiaDockerV1CommandPlugin.class);

    private static final String ENDPOINT_KEY =
        "yarn.nodemanager.resource-plugins.gpu.docker-plugin.nvidia-docker-v1.endpoint";
    private static final String DEFAULT_ENDPOINT = "http://localhost:3476/v1.0/docker/cli";
    private static final String DEVICE_OPTION = "--device";
    private static final String VOLUME_DRIVER_OPTION = "--volume-driver";
    private static final String MOUNT_RO_OPTION = "--volume";
    private static final String GPU_RESOURCE_NAME = "yarn.io/gpu";

    private final Configuration conf;
    /** Parsed plugin options keyed by option name; {@code null} until initialized. */
    private Map<String, Set<String>> additionalCommands = null;
    /** Volume driver reported by the plugin; stays "local" unless the plugin reports one. */
    private String volumeDriver = "local";

    public NvidiaDockerV1CommandPlugin(Configuration conf) {
        this.conf = conf;
    }

    /**
     * Returns the text after the first '=' of a {@code --option=value} token.
     *
     * @throws IllegalArgumentException if the token contains no '='
     */
    private static String getValue(String input) {
        int index = input.indexOf('=');
        if (index < 0) {
            throw new IllegalArgumentException("Failed to locate '=' from input=" + input);
        }
        return input.substring(index + 1);
    }

    /** Adds {@code value} to the set stored under {@code key}, creating the set if needed. */
    private static void addToCommand(Map<String, Set<String>> commands, String key, String value) {
        Set<String> values = commands.get(key);
        if (values == null) {
            values = new HashSet<>();
            commands.put(key, values);
        }
        values.add(value);
    }

    /** Fetches the raw CLI option string from the plugin endpoint, always closing the stream. */
    private static String readCliOptions(String endpoint) throws IOException {
        URLConnection uc = new URL(endpoint).openConnection();
        uc.setRequestProperty("X-Requested-With", "Curl");
        StringWriter writer = new StringWriter();
        try (InputStream in = uc.getInputStream()) {
            IOUtils.copy(in, writer, "utf-8");
        }
        return writer.toString();
    }

    /**
     * Parses the plugin's CLI option string and publishes the result.
     *
     * <p>State is built in locals and assigned only after every token parsed, so a
     * failure leaves {@code additionalCommands} null and init is retried later.
     *
     * @throws IllegalArgumentException on an unsupported option, a token without '=',
     *     or a {@code --volume} mount that is not read-only
     */
    private void parseCliOptions(String cliOptions) {
        Map<String, Set<String>> commands = new HashMap<>();
        String driver = this.volumeDriver;
        for (String token : cliOptions.trim().split("\\s+")) {
            if (token.isEmpty()) {
                // Only possible when the whole response is blank.
                continue;
            }
            if (token.startsWith(DEVICE_OPTION)) {
                addToCommand(commands, DEVICE_OPTION, getValue(token));
            } else if (token.startsWith(VOLUME_DRIVER_OPTION)) {
                // Must be tested before MOUNT_RO_OPTION: "--volume-driver" also starts with "--volume".
                driver = getValue(token);
                LOG.debug("Found volume-driver:{}", driver);
            } else if (token.startsWith(MOUNT_RO_OPTION)) {
                String mount = getValue(token);
                if (!mount.endsWith(":ro")) {
                    throw new IllegalArgumentException("Should not have mount other than ro, command=" + token);
                }
                // Store "source:target" with the trailing ":ro" stripped.
                addToCommand(commands, MOUNT_RO_OPTION, mount.substring(0, mount.lastIndexOf(':')));
            } else {
                throw new IllegalArgumentException("Unsupported option:" + token);
            }
        }
        this.volumeDriver = driver;
        this.additionalCommands = commands;
    }

    /**
     * Loads the plugin's docker options. Callers must hold this object's monitor.
     *
     * <p>An empty endpoint disables the plugin query; an empty option map is cached so
     * later calls add nothing instead of dereferencing null.
     *
     * @throws ContainerExecutionException wrapping any I/O or parse failure
     */
    private void init() throws ContainerExecutionException {
        String endpoint = this.conf.get(ENDPOINT_KEY, DEFAULT_ENDPOINT);
        if (endpoint == null || endpoint.isEmpty()) {
            LOG.info(ENDPOINT_KEY + " set to empty, skip init ..");
            this.additionalCommands = Collections.emptyMap();
            return;
        }
        try {
            String cliOptions = readCliOptions(endpoint);
            LOG.info("Additional docker CLI options from plugin to run GPU containers:{}", cliOptions);
            parseCliOptions(cliOptions);
        } catch (RuntimeException e) {
            LOG.warn("RuntimeException of {} init:", this.getClass().getSimpleName(), e);
            throw new ContainerExecutionException(e);
        } catch (IOException e) {
            LOG.warn("IOException of {} init:", this.getClass().getSimpleName(), e);
            throw new ContainerExecutionException(e);
        }
    }

    /**
     * Extracts the GPU index from a device path whose final "nvidia" is followed only
     * by digits (e.g. "/dev/nvidia3" gives 3).
     *
     * @return the index, or -1 if there is no "nvidia", the suffix is empty or
     *     non-numeric, or the number does not fit in an int
     */
    private static int getGpuIndexFromDeviceName(String device) {
        final String nvidia = "nvidia";
        int idx = device.lastIndexOf(nvidia);
        if (idx < 0) {
            return -1;
        }
        String suffix = device.substring(idx + nvidia.length());
        if (suffix.isEmpty()) {
            return -1;
        }
        for (int i = 0; i < suffix.length(); ++i) {
            if (!Character.isDigit(suffix.charAt(i))) {
                return -1;
            }
        }
        try {
            return Integer.parseInt(suffix);
        } catch (NumberFormatException e) {
            // Digit run too long for an int: not a usable GPU index.
            return -1;
        }
    }

    /** Returns the GPU devices assigned to the container; empty (never null) if none. */
    private Set<GpuDevice> getAssignedGpus(Container container) {
        Set<GpuDevice> assignedResources = new HashSet<>();
        ResourceMappings resourceMappings = container.getResourceMappings();
        if (resourceMappings != null) {
            for (Serializable s : resourceMappings.getAssignedResources(GPU_RESOURCE_NAME)) {
                assignedResources.add((GpuDevice) s);
            }
        }
        return assignedResources;
    }

    @VisibleForTesting
    protected boolean requestsGpu(Container container) {
        return GpuResourceAllocator.getRequestedGpus(container.getResource()) > 0;
    }

    /**
     * Lazily initializes plugin options if the container requests GPUs.
     *
     * <p>Synchronized because both public entry points can trigger the one-time init.
     *
     * @return false if the container requests no GPUs (nothing was initialized)
     */
    private synchronized boolean initializeWhenGpuRequested(Container container)
            throws ContainerExecutionException {
        if (!this.requestsGpu(container)) {
            return false;
        }
        if (this.additionalCommands == null) {
            this.init();
        }
        return true;
    }

    /**
     * Adds plugin-reported devices to the command. Devices without a GPU index are
     * always added; indexed devices only when assigned to the container.
     *
     * @throws ContainerExecutionException if fewer matching devices were found than
     *     GPUs assigned
     */
    private void addDevices(DockerRunCommand dockerRunCommand, Set<String> deviceNames,
            Set<GpuDevice> assignedResources) throws ContainerExecutionException {
        int foundGpuDevices = 0;
        for (String deviceName : deviceNames) {
            int gpuIdx = getGpuIndexFromDeviceName(deviceName);
            if (gpuIdx < 0) {
                dockerRunCommand.addDevice(deviceName, deviceName);
                continue;
            }
            for (GpuDevice gpuDevice : assignedResources) {
                if (gpuDevice.getIndex() == gpuIdx) {
                    ++foundGpuDevices;
                    dockerRunCommand.addDevice(deviceName, deviceName);
                }
            }
        }
        if (foundGpuDevices < assignedResources.size()) {
            throw new ContainerExecutionException("Cannot get all assigned Gpu devices from docker plugin output");
        }
    }

    /**
     * Adds each "source:target" entry as a read-only mount (split at the first ':').
     *
     * @throws ContainerExecutionException if an entry contains no ':'
     */
    private static void addReadOnlyMounts(DockerRunCommand dockerRunCommand, Set<String> mounts)
            throws ContainerExecutionException {
        for (String mount : mounts) {
            int idx = mount.indexOf(':');
            if (idx < 0) {
                throw new ContainerExecutionException("Invalid mount from docker plugin output, expected source:target, mount=" + mount);
            }
            dockerRunCommand.addReadOnlyMountLocation(mount.substring(0, idx), mount.substring(idx + 1), true);
        }
    }

    @Override
    public synchronized void updateDockerRunCommand(DockerRunCommand dockerRunCommand, Container container)
            throws ContainerExecutionException {
        if (!this.initializeWhenGpuRequested(container)) {
            return;
        }
        Set<GpuDevice> assignedResources = this.getAssignedGpus(container);
        if (assignedResources.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Set<String>> option : this.additionalCommands.entrySet()) {
            String key = option.getKey();
            Set<String> values = option.getValue();
            if (DEVICE_OPTION.equals(key)) {
                addDevices(dockerRunCommand, values, assignedResources);
            } else if (MOUNT_RO_OPTION.equals(key)) {
                addReadOnlyMounts(dockerRunCommand, values);
            } else {
                throw new ContainerExecutionException("Unsupported option:" + key);
            }
        }
    }

    /**
     * Builds a "create" volume command for the first read-only mount whose source
     * matches {@link DockerVolumeCommand#VOLUME_NAME_PATTERN}, using the plugin's
     * volume driver.
     *
     * @return the command, or null if no GPUs were requested, the plugin reported no
     *     read-only mounts, or no mount source is a named volume
     */
    @Override
    public DockerVolumeCommand getCreateDockerVolumeCommand(Container container)
            throws ContainerExecutionException {
        if (!this.initializeWhenGpuRequested(container)) {
            return null;
        }
        Set<String> mounts = this.additionalCommands.get(MOUNT_RO_OPTION);
        if (mounts == null) {
            return null;
        }
        String newVolumeName = null;
        for (String mount : mounts) {
            int idx = mount.indexOf(':');
            if (idx < 0) {
                continue;
            }
            String mountSource = mount.substring(0, idx);
            if (DockerVolumeCommand.VOLUME_NAME_PATTERN.matcher(mountSource).matches()) {
                newVolumeName = mountSource;
                LOG.debug("Found volume name for GPU:{}", newVolumeName);
                break;
            }
            LOG.debug("Failed to match {} to named-volume regex pattern", mountSource);
        }
        if (newVolumeName == null) {
            return null;
        }
        DockerVolumeCommand command = new DockerVolumeCommand("create");
        command.setDriverName(this.volumeDriver);
        command.setVolumeName(newVolumeName);
        return command;
    }

    @Override
    public DockerVolumeCommand getCleanupDockerVolumesCommand(Container container)
            throws ContainerExecutionException {
        return null;
    }
}

