/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed;

import java.net.InterfaceAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.logic.sequence.BasicSequenceProvider;
import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage;
import org.nd4j.parameterserver.distributed.messages.Frame;
import org.nd4j.parameterserver.distributed.messages.MeaningfulMessage;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.messages.VoidMessage;
import org.nd4j.parameterserver.distributed.messages.requests.InitializationRequestMessage;
import org.nd4j.parameterserver.distributed.messages.requests.VectorRequestMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.nd4j.parameterserver.distributed.util.NetworkOrganizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Process-wide parameter server used by the legacy distributed training code path.
 *
 * <p>A single shared instance is exposed through {@link #getInstance()}. {@link #init} resolves this
 * node's {@link NodeRole} (shard, backup or client) and network address, initialises the
 * {@link Transport}, optionally starts daemon threads that pull messages off the transport and
 * dispatch them via {@link #handleMessage(VoidMessage)}, and finally initialises the
 * {@link TrainingDriver}.
 *
 * @deprecated legacy implementation. TODO(review): reference the replacement API here.
 */
@Deprecated
public class VoidParameterServer {
    private static final Logger log = LoggerFactory.getLogger(VoidParameterServer.class);

    /** Number of stacked messages after which {@link #execDistributed(TrainingMessage)} sends a frame. */
    private static final int FRAME_FLUSH_THRESHOLD = 128;

    // INSTANCE is constructed during static initialisation, BEFORE the non-constant static fields
    // declared below it (numThreads, MAX_EXP) are assigned. Instance initialisers must therefore not
    // read those fields (see the executor field below).
    private static final VoidParameterServer INSTANCE = new VoidParameterServer();

    protected volatile NodeRole nodeRole;
    protected volatile VoidConfiguration voidConfiguration;

    // Guards one-time initialisation; shutdown() resets it so that init() may run again.
    protected AtomicBoolean initLocker = new AtomicBoolean(false);
    protected AtomicBoolean initFinished = new AtomicBoolean(false);
    protected AtomicBoolean shutdownLocker = new AtomicBoolean(false);
    protected AtomicBoolean shutdownFinished = new AtomicBoolean(false);

    protected transient Transport transport;
    // When true, init() does not start the message-handling threads.
    protected transient AtomicBoolean manualMode = new AtomicBoolean(false);
    // Loop flag of the message-handling threads; set by init(), cleared by shutdown().
    protected transient AtomicBoolean runner = new AtomicBoolean(false);
    protected transient Thread[] processingThreads;
    protected transient Runnable[] processingRunnables;
    protected transient TrainingDriver<? extends TrainingMessage> trainer;
    protected short shardIndex;
    protected Clipboard clipboard = new Clipboard();
    protected Storage storage = new WordVectorStorage();
    // Frames still being filled by execDistributed(TrainingMessage), keyed by message class simple name.
    protected Map<String, Frame<TrainingMessage>> frames = new ConcurrentHashMap<String, Frame<TrainingMessage>>();
    protected static final int numThreads = Runtime.getRuntime().availableProcessors() * 2;
    // Sized inline instead of via numThreads: numThreads is still 0 while INSTANCE is being built,
    // and newFixedThreadPool(0) would throw.
    protected ThreadPoolExecutor executor =
            (ThreadPoolExecutor) Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
    protected static double MAX_EXP = 6.0;

    /** Creates a server whose role is still {@link NodeRole#NONE}, to be resolved by {@link #init}. */
    protected VoidParameterServer() {
        this.nodeRole = NodeRole.NONE;
    }

    /**
     * Creates a server with a preset role.
     *
     * @param nodeRole role of this node; must not be null
     * @throws NullPointerException if {@code nodeRole} is null
     */
    protected VoidParameterServer(@NonNull NodeRole nodeRole) {
        this.nodeRole = Objects.requireNonNull(nodeRole, "nodeRole is marked @NonNull but is null");
    }

    /**
     * Creates a server with role {@link NodeRole#NONE}.
     *
     * @param manualMode if true, {@link #init} will not start message-handling threads
     */
    protected VoidParameterServer(boolean manualMode) {
        this();
        this.manualMode.set(manualMode);
    }

    /** @return the shared process-wide instance */
    public static VoidParameterServer getInstance() {
        return INSTANCE;
    }

    /**
     * Replaces the training driver.
     *
     * @param trainer new driver; must not be null
     * @throws NullPointerException if {@code trainer} is null
     */
    public void setTrainingDriver(@NonNull TrainingDriver<? extends TrainingMessage> trainer) {
        this.trainer = Objects.requireNonNull(trainer, "trainer is marked @NonNull but is null");
    }

    /** @return index of this node among the configured shards (0 unless resolved otherwise) */
    public short getShardIndex() {
        return this.shardIndex;
    }

    /** Forwards the given address to the transport via {@link Transport#setIpAndPort}. */
    protected void setIpPortForShard(String ip, int port) {
        this.transport.setIpAndPort(ip, port);
    }

    protected void setShardIndex(short idx) {
        this.shardIndex = idx;
    }

    protected Transport getTransport() {
        return this.transport;
    }

    protected INDArray getSyn0() {
        return this.storage.getArray(WordVectorStorage.SYN_0);
    }

    protected INDArray getSyn1() {
        return this.storage.getArray(WordVectorStorage.SYN_1);
    }

    protected INDArray getSyn1Neg() {
        return this.storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
    }

    protected INDArray getExpTable() {
        return this.storage.getArray(WordVectorStorage.EXP_TABLE);
    }

    protected INDArray getNegTable() {
        return this.storage.getArray(WordVectorStorage.NEGATIVE_TABLE);
    }

    /**
     * Initialises the server with a {@link RoutedTransport} and a {@link SkipGramTrainer}.
     *
     * @param voidConfiguration cluster configuration; must not be null
     * @throws NullPointerException if {@code voidConfiguration} is null
     */
    protected void init(@NonNull VoidConfiguration voidConfiguration) {
        Objects.requireNonNull(voidConfiguration, "voidConfiguration is marked @NonNull but is null");
        this.init(voidConfiguration, new RoutedTransport(), new SkipGramTrainer());
    }

    /** @return true once {@link #init} has completed (and until {@link #shutdown()} is called) */
    public boolean isInit() {
        return this.initFinished.get();
    }

    /**
     * Initialises this server once; subsequent calls are no-ops until {@link #shutdown()}.
     *
     * <p>Resolves the node role and address (unless a role is already set or forced by the
     * configuration), initialises and launches the transport, starts the message-handling threads
     * unless manual mode is enabled, and initialises the training driver.
     *
     * @param voidConfiguration cluster configuration; must not be null
     * @param transport transport to use; must not be null
     * @param trainer training driver; it is dereferenced at the end of initialisation, so a null
     *     value causes a NullPointerException after the transport has already been launched
     * @throws NullPointerException if {@code voidConfiguration} or {@code transport} is null
     * @throws ND4JIllegalStateException if no address can be determined for a client node
     */
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport,
                    TrainingDriver<? extends TrainingMessage> trainer) {
        Objects.requireNonNull(voidConfiguration, "voidConfiguration is marked @NonNull but is null");
        Objects.requireNonNull(transport, "transport is marked @NonNull but is null");

        // Fast path: already initialised.
        if (this.initFinished.get()) {
            return;
        }

        synchronized (this) {
            // Only the first caller performs initialisation.
            if (!this.initLocker.compareAndSet(false, true)) {
                return;
            }
            this.trainer = trainer;
            this.voidConfiguration = voidConfiguration;
            this.transport = transport;

            NodeRole forcedRole = voidConfiguration.getForcedRole();
            if (this.nodeRole == NodeRole.NONE && (forcedRole == null || forcedRole == NodeRole.NONE)) {
                initTransportWithDetectedRole(voidConfiguration);
            } else {
                initTransportWithPresetRole(voidConfiguration);
            }

            if (!this.manualMode.get()) {
                startProcessingThreads(transport);
            }

            log.info("Launching transport...");
            transport.launch(Transport.ThreadingModel.DEDICATED_THREADS);
            trainer.init(this.voidConfiguration, this.transport, this.storage, this.clipboard);
            this.initFinished.set(true);
        }
    }

    /** Detects role, address and shard index of this node, then initialises the transport. */
    private void initTransportWithDetectedRole(VoidConfiguration configuration) {
        Pair<NodeRole, String> roleAndAddress;
        if (configuration.getShardAddresses().size() == 1
                && configuration.getShardAddresses().get(0).contains("127.0.0.1")) {
            // A single local shard: this node is that shard.
            roleAndAddress = Pair.create(NodeRole.SHARD, configuration.getShardAddresses().get(0));
        } else {
            roleAndAddress = this.getRole(configuration, getLocalAddresses());
        }
        this.nodeRole = roleAndAddress.getFirst();
        String ipAndPort = roleAndAddress.getSecond();

        String ip;
        int port;
        if (ipAndPort.contains(":")) {
            String[] split = ipAndPort.split(":");
            ip = split[0];
            port = Integer.parseInt(split[1]);
        } else {
            ip = ipAndPort;
            port = configuration.getUnicastControllerPort();
        }

        if (this.nodeRole == NodeRole.SHARD && configuration.getShardAddresses().size() > 1) {
            short index = 0;
            for (String shard : configuration.getShardAddresses()) {
                // Compare host parts only: a shard entry may be "host" or "host:port".
                String shardIp = shard.contains(":") ? shard.split(":")[0] : shard;
                if (shardIp.equals(ip)) {
                    this.shardIndex = index;
                }
                index++;
            }
        }
        this.transport.init(configuration, this.clipboard, this.nodeRole, ip, port, this.shardIndex);
    }

    /** Initialises the transport for a role that was preset or forced by the configuration. */
    private void initTransportWithPresetRole(VoidConfiguration configuration) {
        if (this.nodeRole == NodeRole.NONE) {
            this.nodeRole = configuration.getForcedRole();
        }
        String localIp = configuration.getExecutionMode() == ExecutionMode.MANAGED
                ? configuration.getControllerAddress()
                : "127.0.0.1";
        this.transport.init(configuration, this.clipboard, this.nodeRole, localIp,
                configuration.getUnicastControllerPort(), this.shardIndex);
    }

    /** Starts {@link #numThreads} daemon threads that consume messages from {@code source}. */
    private void startProcessingThreads(Transport source) {
        this.processingThreads = new Thread[numThreads];
        this.processingRunnables = new Runnable[numThreads];
        // Raised once, before any thread starts, so a late-starting thread cannot undo shutdown().
        this.runner.set(true);
        Integer device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        for (int x = 0; x < numThreads; x++) {
            Thread thread = new Thread(() -> processMessages(source));
            this.processingThreads[x] = thread;
            Nd4j.getAffinityManager().attachThreadToDevice(thread, device);
            thread.setDaemon(true);
            thread.setName("VoidParameterServer messages handling thread");
            thread.start();
        }
    }

    /** Worker loop: takes messages from {@code source} and handles them until {@link #runner} is cleared. */
    private void processMessages(Transport source) {
        while (this.runner.get()) {
            try {
                this.handleMessage(source.takeMessage());
            } catch (Exception e) {
                if (!this.runner.get()) {
                    // Shutdown was requested while this thread was busy; exit instead of failing.
                    break;
                }
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Enables or disables manual mode (see {@link #VoidParameterServer(boolean)}).
     *
     * @return this instance, for chaining
     */
    protected VoidParameterServer toggleManualMode(boolean mode) {
        this.manualMode.set(mode);
        return this;
    }

    /**
     * Determines this node's role by matching configured addresses against local IPs.
     *
     * <p>The first shard address whose host part is local yields {@link NodeRole#SHARD}, otherwise
     * the first matching backup address yields {@link NodeRole#BACKUP}. Otherwise the node is a
     * {@link NodeRole#CLIENT} whose IP comes from the configured network mask or, failing that, the
     * {@code DL4J_VOID_IP} environment variable.
     *
     * @param voidConfiguration cluster configuration; must not be null
     * @param localIPs IP addresses of this machine; must not be null
     * @return role and the matching address (for clients: {@code ip:unicastControllerPort})
     * @throws NullPointerException if an argument is null
     * @throws ND4JIllegalStateException if no client IP can be determined
     */
    protected Pair<NodeRole, String> getRole(@NonNull VoidConfiguration voidConfiguration,
                                             @NonNull Collection<String> localIPs) {
        Objects.requireNonNull(voidConfiguration, "voidConfiguration is marked @NonNull but is null");
        Objects.requireNonNull(localIPs, "localIPs is marked @NonNull but is null");

        for (String ip : voidConfiguration.getShardAddresses()) {
            if (localIPs.contains(stripPort(ip))) {
                return Pair.create(NodeRole.SHARD, ip);
            }
        }
        if (voidConfiguration.getBackupAddresses() != null) {
            for (String ip : voidConfiguration.getBackupAddresses()) {
                if (localIPs.contains(stripPort(ip))) {
                    return Pair.create(NodeRole.BACKUP, ip);
                }
            }
        }

        String sparkIp = null;
        if (voidConfiguration.getNetworkMask() != null) {
            NetworkOrganizer organizer = new NetworkOrganizer(voidConfiguration.getNetworkMask());
            sparkIp = organizer.getMatchingAddress();
        }
        if (sparkIp == null) {
            sparkIp = System.getenv("DL4J_VOID_IP");
        }
        log.info("Got [{}] as sparkIp", sparkIp);
        if (sparkIp == null) {
            throw new ND4JIllegalStateException("Can't get IP address for UDP communcation");
        }
        return Pair.create(NodeRole.CLIENT, sparkIp + ":" + voidConfiguration.getUnicastControllerPort());
    }

    /** Removes everything from the first ':' onwards (the port suffix of "host:port"). */
    private static String stripPort(String address) {
        return address.replaceAll(":.*", "");
    }

    /**
     * Shuts down the transport and executor and resets the init flags so {@link #init} can run
     * again. Does nothing if the server was never initialised or a shutdown is already in progress.
     */
    public void shutdown() {
        if (this.initLocker.get() && this.shutdownLocker.compareAndSet(false, true)) {
            log.info("Shutting down transport...");
            // Lower the worker loop flag first so threads failing on the closed transport exit quietly.
            this.runner.set(false);
            this.transport.shutdown();
            this.executor.shutdown();
            this.initFinished.set(false);
            this.initLocker.set(false);
            this.shutdownLocker.set(false);
        }
    }

    /**
     * Collects the non-IPv6 addresses of all non-loopback interfaces that are up.
     *
     * @return local addresses (possibly empty)
     * @throws RuntimeException wrapping a {@link SocketException} raised while querying interfaces
     */
    public static Set<String> getLocalAddresses() {
        Set<String> result = new HashSet<String>();
        try {
            Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
            // Java 8 specifies a null return when no interfaces are found.
            if (interfaces == null) {
                return result;
            }
            for (NetworkInterface networkInterface : Collections.list(interfaces)) {
                if (networkInterface.isLoopback() || !networkInterface.isUp()) {
                    continue;
                }
                for (InterfaceAddress address : networkInterface.getInterfaceAddresses()) {
                    String addr = address.getAddress().getHostAddress();
                    // Addresses containing ':' (IPv6) are skipped.
                    if (addr != null && !addr.isEmpty() && !addr.contains(":")) {
                        result.add(addr);
                    }
                }
            }
            return result;
        } catch (SocketException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Processes a message locally unless it targets a different shard.
     *
     * @param message incoming message; must not be null
     * @throws NullPointerException if {@code message} is null
     */
    protected void handleMessage(@NonNull VoidMessage message) {
        Objects.requireNonNull(message, "message is marked @NonNull but is null");
        // A non-negative target id addresses a specific shard.
        if (message.getTargetId() >= 0 && message.getTargetId() != this.shardIndex) {
            log.warn("sI_{}: Skipping message: [{}]; TargetIdx: [{}]",
                    this.shardIndex, message.getClass().getSimpleName(), message.getTargetId());
            return;
        }
        message.attachContext(this.voidConfiguration, this.trainer, this.clipboard, this.transport,
                this.storage, this.nodeRole, this.shardIndex);
        message.processMessage();
    }

    /** Sends an {@link InitializationRequestMessage} built from the given parameters via the transport. */
    public void initializeSeqVec(int vectorLength, int numWords, long seed, int columnsPerShard,
                                 boolean useHs, boolean useNegSampling) {
        InitializationRequestMessage dim = new InitializationRequestMessage(
                vectorLength, numWords, seed, useHs, useNegSampling, columnsPerShard);
        this.transport.sendMessage(dim);
    }

    /**
     * Stacks a message into the pending frame for its class and sends the frame once it holds
     * {@value #FRAME_FLUSH_THRESHOLD} or more messages.
     *
     * @param message message to batch; must not be null
     * @throws NullPointerException if {@code message} is null
     */
    public synchronized void execDistributed(@NonNull TrainingMessage message) {
        Objects.requireNonNull(message, "message is marked @NonNull but is null");
        String key = message.getClass().getSimpleName();
        Frame<TrainingMessage> currentFrame = this.frames.get(key);
        if (currentFrame == null) {
            currentFrame = newFrame();
            this.frames.put(key, currentFrame);
        }
        currentFrame.stackMessage(message);
        if (currentFrame.size() >= FRAME_FLUSH_THRESHOLD) {
            this.transport.sendMessage(currentFrame);
            this.frames.put(key, newFrame());
        }
    }

    /** Creates an empty frame with the next sequence id. */
    private Frame<TrainingMessage> newFrame() {
        return new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
    }

    /**
     * Sends a message to all shards without batching.
     *
     * @throws NullPointerException if {@code message} is null
     */
    public void execDistributedImmediately(@NonNull TrainingMessage message) {
        Objects.requireNonNull(message, "message is marked @NonNull but is null");
        this.transport.sendMessageToAllShards(message);
    }

    /**
     * Sends an already-built frame via the transport.
     *
     * @throws NullPointerException if {@code messages} is null
     */
    public void execDistributed(@NonNull Frame<? extends TrainingMessage> messages) {
        Objects.requireNonNull(messages, "messages is marked @NonNull but is null");
        this.transport.sendMessage(messages);
    }

    /** Requests row {@code rowIdx} of {@link WordVectorStorage#SYN_0}. */
    public INDArray getVector(int rowIdx) {
        return this.getVector(WordVectorStorage.SYN_0, rowIdx);
    }

    /**
     * Requests a row of the given storage array and returns the payload of the response.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public INDArray getVector(@NonNull Integer key, int rowIdx) {
        Objects.requireNonNull(key, "key is marked @NonNull but is null");
        VectorRequestMessage message = new VectorRequestMessage(key, rowIdx);
        MeaningfulMessage response = this.transport.sendMessageAndGetResponse(message);
        // NOTE(review): a null response (if the transport can return one) would NPE here — confirm.
        return response.getPayload();
    }

    /**
     * Broadcasts a message to all shards.
     *
     * @throws NullPointerException if {@code message} is null
     */
    public synchronized void sendMessageToAllShards(@NonNull VoidMessage message) {
        Objects.requireNonNull(message, "message is marked @NonNull but is null");
        this.transport.sendMessageToAllShards(message);
    }

    /**
     * Broadcasts a message to all clients, with no exclusions.
     *
     * @throws NullPointerException if {@code message} is null
     */
    public void sendMessageToAllClients(@NonNull VoidMessage message) {
        Objects.requireNonNull(message, "message is marked @NonNull but is null");
        // An explicit empty array is required: a no-varargs call would resolve to this overload again.
        this.sendMessageToAllClients(message, new Long[0]);
    }

    /**
     * Broadcasts a message to all clients except the given ones.
     *
     * @param message message to send; must not be null
     * @param exclusions ids forwarded to the transport as exclusions; null is treated as none
     * @throws NullPointerException if {@code message} is null
     */
    public synchronized void sendMessageToAllClients(@NonNull VoidMessage message, Long... exclusions) {
        Objects.requireNonNull(message, "message is marked @NonNull but is null");
        Long[] excluded = exclusions == null ? new Long[0] : exclusions;
        this.transport.sendMessageToAllClients(message, excluded);
    }

    /** @return the role of this node ({@link NodeRole#NONE} until resolved) */
    public NodeRole getNodeRole() {
        return this.nodeRole;
    }
}

