/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.v2.util;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Optional;
import org.nd4j.linalg.util.ND4JFileUtils;
import org.nd4j.linalg.util.SerializationUtils;
import org.nd4j.parameterserver.distributed.v2.chunks.ChunksTracker;
import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk;
import org.nd4j.parameterserver.distributed.v2.chunks.impl.FileChunksTracker;
import org.nd4j.parameterserver.distributed.v2.chunks.impl.InmemoryChunksTracker;
import org.nd4j.parameterserver.distributed.v2.messages.VoidMessage;

/**
 * Splits {@link VoidMessage}s into bounded-size {@link VoidChunk}s and reassembles chunks back
 * into messages.
 *
 * <p>Messages that are partially received are tracked per original message id in
 * {@link #trackers}. When a new message starts arriving, its reassembly tracker is kept in memory
 * if the running in-memory total ({@link #memoryUse}) plus the message size stays below the
 * caller-supplied limit. Otherwise a file-backed tracker is used.
 */
public class MessageSplitter {
    private static final MessageSplitter INSTANCE = new MessageSplitter();

    /** Trackers for messages currently being reassembled, keyed by original message id. */
    protected Map<String, ChunksTracker> trackers = new ConcurrentHashMap<String, ChunksTracker>();

    /** Sum of {@code chunk.getTotalSize()} over all messages currently tracked in memory. */
    protected final AtomicLong memoryUse = new AtomicLong(0L);

    /**
     * Returns the shared splitter instance.
     *
     * @return the process-wide {@code MessageSplitter}
     */
    public static MessageSplitter getInstance() {
        return INSTANCE;
    }

    /**
     * Serializes {@code message} and cuts the serialized bytes into chunks of at most
     * {@code maxBytes} bytes each.
     *
     * <p>The message is first serialized to a temporary file. That file is deleted before this
     * method returns, including when an exception is thrown.
     *
     * @param message  message to split
     * @param maxBytes maximum payload size per chunk; every chunk except possibly the last one
     *                 carries exactly this many bytes
     * @return the chunks in order, with chunk ids {@code 0 .. numberOfChunks - 1}
     * @throws NullPointerException      if {@code message} is null
     * @throws ND4JIllegalStateException if {@code maxBytes <= 0}
     * @throws IOException               if the temporary file cannot be written or read back
     */
    public Collection<VoidChunk> split(@NonNull VoidMessage message, int maxBytes) throws IOException {
        if (message == null) {
            throw new NullPointerException("message is marked @NonNull but is null");
        }
        if (maxBytes <= 0) {
            throw new ND4JIllegalStateException("MaxBytes must be > 0");
        }
        File tempFile = ND4JFileUtils.createTempFile("messageSplitter", "temp");
        try {
            // Close (and therefore flush) the write side before measuring the file, so the length
            // covers everything that was serialized.
            try (FileOutputStream fos = new FileOutputStream(tempFile);
                 BufferedOutputStream bos = new BufferedOutputStream(fos)) {
                SerializationUtils.serialize((Serializable) message, (OutputStream) bos);
            }

            long length = tempFile.length();
            int numChunks = (int) (length / (long) maxBytes + (length % (long) maxBytes > 0L ? 1 : 0));
            ArrayList<VoidChunk> result = new ArrayList<VoidChunk>(numChunks);

            try (FileInputStream fis = new FileInputStream(tempFile);
                 BufferedInputStream bis = new BufferedInputStream(fis)) {
                byte[] buffer = new byte[maxBytes];
                long offset = 0L;
                for (int id = 0; id < numChunks; id++) {
                    // Fill each chunk completely (a single read() may return fewer bytes), so the
                    // numberOfChunks computed above and the splitSize recorded in every chunk
                    // match the actual payload sizes.
                    int chunkSize = (int) Math.min((long) maxBytes, length - offset);
                    readFully(bis, buffer, chunkSize);
                    VoidChunk chunk = VoidChunk.builder()
                            .messageId(UUID.randomUUID().toString())
                            .originalId(message.getMessageId())
                            .chunkId(id)
                            .numberOfChunks(numChunks)
                            .splitSize(maxBytes)
                            .payload(Arrays.copyOf(buffer, chunkSize))
                            .totalSize(length)
                            .build();
                    result.add(chunk);
                    offset += chunkSize;
                }
            }
            return result;
        } finally {
            // Best-effort cleanup; the return value is intentionally ignored.
            tempFile.delete();
        }
    }

    /**
     * Reads exactly {@code len} bytes from {@code in} into the start of {@code buffer}.
     *
     * @throws IOException if the stream ends before {@code len} bytes have been read
     */
    private static void readFully(InputStream in, byte[] buffer, int len) throws IOException {
        int filled = 0;
        while (filled < len) {
            int n = in.read(buffer, filled, len - filled);
            if (n < 0) {
                throw new IOException("Unexpected end of serialized message: expected " + len
                        + " bytes, got " + filled);
            }
            filled += n;
        }
    }

    /**
     * Checks whether a message with the given original id is currently being reassembled.
     *
     * @param messageId original message id
     * @return {@code true} if a tracker is registered for {@code messageId}
     * @throws NullPointerException if {@code messageId} is null
     */
    boolean isTrackedMessage(@NonNull String messageId) {
        if (messageId == null) {
            throw new NullPointerException("messageId is marked @NonNull but is null");
        }
        return this.trackers.containsKey(messageId);
    }

    /**
     * Checks whether the message that {@code chunk} belongs to is currently being reassembled.
     *
     * @param chunk any chunk of the message
     * @return {@code true} if a tracker is registered for {@code chunk.getOriginalId()}
     * @throws NullPointerException if {@code chunk} is null
     */
    boolean isTrackedMessage(@NonNull VoidChunk chunk) {
        if (chunk == null) {
            throw new NullPointerException("chunk is marked @NonNull but is null");
        }
        return this.isTrackedMessage(chunk.getOriginalId());
    }

    /**
     * Same as {@link #merge(VoidChunk, long)} with a memory limit of {@code -1}. Assuming
     * non-negative sizes, that limit is never satisfied, so new messages always get a file-backed
     * tracker.
     *
     * @param chunk incoming chunk
     * @param <T>   expected message type
     * @return the reassembled message, or empty if more chunks are still expected
     * @throws NullPointerException if {@code chunk} is null
     */
    public <T extends VoidMessage> Optional<T> merge(@NonNull VoidChunk chunk) {
        if (chunk == null) {
            throw new NullPointerException("chunk is marked @NonNull but is null");
        }
        return this.merge(chunk, -1L);
    }

    /**
     * Currently a no-op: it does not remove or release any tracker for {@code messageId}.
     *
     * @param messageId original message id (ignored)
     */
    public void release(String messageId) {
    }

    /**
     * Feeds one chunk into the tracker for its original message. The reassembled message is
     * returned once the tracker's {@code append} reports completion.
     *
     * <p>If no tracker is registered yet for {@code chunk.getOriginalId()}, one is created. It is
     * in-memory when {@code memoryUse + chunk.getTotalSize() < memoryLimit} and file-backed
     * otherwise. When the message completes, the tracker is unregistered, its in-memory
     * accounting is reversed, and it is released.
     *
     * @param chunk       incoming chunk
     * @param memoryLimit upper bound (exclusive) for {@link #memoryUse} when deciding whether a new
     *                    message may be reassembled in memory
     * @param <T>         expected message type
     * @return the reassembled message, or empty if more chunks are still expected
     * @throws NullPointerException if {@code chunk} is null
     */
    @SuppressWarnings("unchecked")
    public <T extends VoidMessage> Optional<T> merge(@NonNull VoidChunk chunk, long memoryLimit) {
        if (chunk == null) {
            throw new NullPointerException("chunk is marked @NonNull but is null");
        }
        String originalId = chunk.getOriginalId();

        // Only build a tracker when none is registered: constructing one per chunk and throwing it
        // away wastes allocations and may leave tracker resources unreleased.
        ChunksTracker tracker = this.trackers.get(originalId);
        if (tracker == null) {
            tracker = this.registerTracker(chunk, memoryLimit);
        }

        if (!tracker.append(chunk)) {
            return Optional.empty();
        }
        try {
            return Optional.of((T) tracker.getMessage());
        } finally {
            // Unregister before releasing so no other caller picks up a released tracker.
            this.trackers.remove(originalId);
            if (tracker instanceof InmemoryChunksTracker) {
                this.memoryUse.addAndGet(-chunk.getTotalSize());
            }
            tracker.release();
        }
    }

    /**
     * Creates a tracker for the message {@code chunk} belongs to and registers it, unless one was
     * registered concurrently. Returns whichever tracker ends up registered.
     */
    private ChunksTracker registerTracker(VoidChunk chunk, long memoryLimit) {
        boolean inMemory = this.memoryUse.get() + chunk.getTotalSize() < memoryLimit;
        ChunksTracker candidate;
        if (inMemory) {
            candidate = new InmemoryChunksTracker(chunk);
        } else {
            candidate = new FileChunksTracker(chunk);
        }

        ChunksTracker existing = this.trackers.putIfAbsent(chunk.getOriginalId(), candidate);
        if (existing != null) {
            // Another caller registered a tracker first: discard ours and use theirs.
            candidate.release();
            return existing;
        }
        if (inMemory) {
            this.memoryUse.addAndGet(chunk.getTotalSize());
        }
        return candidate;
    }

    /**
     * Drops all partially reassembled messages and resets the in-memory accounting. Each tracker
     * is released before it is forgotten, so resources it holds are freed.
     */
    public void reset() {
        for (String id : new ArrayList<String>(this.trackers.keySet())) {
            ChunksTracker tracker = this.trackers.remove(id);
            if (tracker != null) {
                tracker.release();
            }
        }
        this.memoryUse.set(0L);
    }
}

