/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RescaledChannelsMapping;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.Preconditions;

/**
 * Bookkeeping for assigning previously checkpointed state to the subtasks of one {@link
 * ExecutionJobVertex}.
 *
 * <p>The per-instance state maps are keyed by {@link OperatorInstanceID} and are populated by code
 * outside this class. {@link #getSubtaskState(OperatorInstanceID)} bundles the entries of a single
 * operator instance into an {@link OperatorSubtaskState}. The upstream/downstream assignments and the
 * subtask mappings are used to build the {@link InflightDataRescalingDescriptor}s attached to the
 * vertex's input and output operators.
 */
class TaskStateAssignment {
    /** The job vertex whose subtasks this assignment describes. */
    final ExecutionJobVertex executionJobVertex;

    /** The previous operator states of this vertex, keyed by operator ID. */
    final Map<OperatorID, OperatorState> oldState;

    /** Whether any operator in {@link #oldState} reports at least one collected state. */
    final boolean hasState;

    /** The parallelism of {@link #executionJobVertex}, i.e. the number of subtasks to assign to. */
    final int newParallelism;

    /**
     * Generated ID of the last entry of the vertex's operator ID list. Instances of this operator are
     * the only ones that receive a non-trivial input rescaling descriptor.
     *
     * <p>NOTE(review): assumes the operator ID list ends with the chain's input-side (head) operator
     * — confirm against {@code ExecutionJobVertex#getOperatorIDs()}.
     */
    final OperatorID inputOperatorID;

    /**
     * Generated ID of the first entry of the vertex's operator ID list. Instances of this operator are
     * the only ones that receive a non-trivial output rescaling descriptor.
     *
     * <p>NOTE(review): assumes the operator ID list starts with the chain's output-side (tail)
     * operator — confirm against {@code ExecutionJobVertex#getOperatorIDs()}.
     */
    final OperatorID outputOperatorID;

    // Per-operator-instance state handles; an absent key means "no state of this kind".
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState;
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState;
    final Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates;
    final Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates;

    /**
     * Subtask mapping on the output side; empty until reassigned elsewhere.
     *
     * <p>NOTE(review): presumably maps a new subtask index to the old subtask indexes whose output
     * data it takes over — confirm with the code that populates it.
     */
    Map<Integer, Set<Integer>> outputSubtaskMappings = Collections.emptyMap();

    /**
     * Subtask mapping on the input side; empty until reassigned elsewhere.
     *
     * <p>NOTE(review): presumably maps a new subtask index to the old subtask indexes whose input
     * data it takes over — confirm with the code that populates it.
     */
    Map<Integer, Set<Integer>> inputSubtaskMappings = Collections.emptyMap();

    /** Assignments of upstream vertices; sized by the number of inputs (key presumably input index). */
    final Map<Integer, TaskStateAssignment> upstreamAssignments;

    /** Assignments of downstream vertices; sized by the number of produced data sets. */
    final Map<Integer, TaskStateAssignment> downstreamAssignments;

    /**
     * Creates an empty assignment for the given vertex.
     *
     * @param executionJobVertex the vertex whose subtasks receive state
     * @param oldState the previous operator states of the vertex, keyed by operator ID
     */
    public TaskStateAssignment(
            ExecutionJobVertex executionJobVertex, Map<OperatorID, OperatorState> oldState) {
        this.executionJobVertex = executionJobVertex;
        this.oldState = oldState;
        this.hasState =
                oldState.values().stream()
                        .anyMatch(operatorState -> operatorState.getNumberCollectedStates() > 0);
        this.newParallelism = executionJobVertex.getParallelism();

        // One entry per (operator, subtask) pair at most.
        int expectedNumberOfSubtasks = this.newParallelism * oldState.size();
        this.subManagedOperatorState = new HashMap<>(expectedNumberOfSubtasks);
        this.subRawOperatorState = new HashMap<>(expectedNumberOfSubtasks);
        this.inputChannelStates = new HashMap<>(expectedNumberOfSubtasks);
        this.resultSubpartitionStates = new HashMap<>(expectedNumberOfSubtasks);
        this.subManagedKeyedState = new HashMap<>(expectedNumberOfSubtasks);
        this.subRawKeyedState = new HashMap<>(expectedNumberOfSubtasks);

        List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
        this.outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID();
        this.inputOperatorID = operatorIDs.get(operatorIDs.size() - 1).getGeneratedOperatorID();

        this.upstreamAssignments = new HashMap<>(executionJobVertex.getInputs().size());
        this.downstreamAssignments =
                new HashMap<>(executionJobVertex.getProducedDataSets().length);
    }

    /**
     * Bundles all state assigned to one operator instance into an {@link OperatorSubtaskState}.
     *
     * @param instanceID the operator instance (operator ID + subtask index) to collect state for
     * @return the subtask state; state kinds without an entry become empty collections, and the
     *     rescaling descriptors are {@link InflightDataRescalingDescriptor#NO_RESCALE} unless the
     *     instance belongs to the input/output operator of this vertex
     * @throws IllegalStateException if the instance has raw keyed state but no managed keyed state
     */
    public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
        Preconditions.checkState(
                subManagedKeyedState.containsKey(instanceID)
                        || !subRawKeyedState.containsKey(instanceID),
                "If an operator has no managed key state, it should also not have a raw keyed state.");

        return OperatorSubtaskState.builder()
                .setManagedOperatorState(getState(instanceID, subManagedOperatorState))
                .setRawOperatorState(getState(instanceID, subRawOperatorState))
                .setManagedKeyedState(getState(instanceID, subManagedKeyedState))
                .setRawKeyedState(getState(instanceID, subRawKeyedState))
                .setInputChannelState(getState(instanceID, inputChannelStates))
                .setResultSubpartitionState(getState(instanceID, resultSubpartitionStates))
                .setInputRescalingDescriptor(createInputRescalingDescriptor(instanceID))
                .setOutputRescalingDescriptor(createOutputRescalingDescriptor(instanceID))
                .build();
    }

    /**
     * Input-side descriptor: only instances of {@link #inputOperatorID} get one, built from the
     * upstream assignments' output mappings and this vertex's input mapping.
     */
    private InflightDataRescalingDescriptor createInputRescalingDescriptor(
            OperatorInstanceID instanceID) {
        if (!inputOperatorID.equals(instanceID.getOperatorId())) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }
        return createRescalingDescriptor(
                instanceID,
                upstreamAssignments,
                assignment -> assignment.outputSubtaskMappings,
                inputSubtaskMappings);
    }

    /**
     * Output-side descriptor: only instances of {@link #outputOperatorID} get one, built from the
     * downstream assignments' input mappings and this vertex's output mapping.
     */
    private InflightDataRescalingDescriptor createOutputRescalingDescriptor(
            OperatorInstanceID instanceID) {
        if (!outputOperatorID.equals(instanceID.getOperatorId())) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }
        return createRescalingDescriptor(
                instanceID,
                downstreamAssignments,
                assignment -> assignment.inputSubtaskMappings,
                outputSubtaskMappings);
    }

    /**
     * Builds a rescaling descriptor for one operator instance.
     *
     * @param instanceID the instance whose subtask index is looked up in {@code subtaskMappings}
     * @param assignments neighbouring assignments, each contributing one {@link
     *     RescaledChannelsMapping} under its key
     * @param mappingRetriever selects which subtask mapping of a neighbour to turn into a channel
     *     mapping
     * @param subtaskMappings this vertex's own subtask mapping on the relevant side
     * @return {@link InflightDataRescalingDescriptor#NO_RESCALE} if both {@code assignments} and
     *     {@code subtaskMappings} are empty, otherwise a descriptor built from them
     */
    private InflightDataRescalingDescriptor createRescalingDescriptor(
            OperatorInstanceID instanceID,
            Map<Integer, TaskStateAssignment> assignments,
            Function<TaskStateAssignment, Map<Integer, Set<Integer>>> mappingRetriever,
            Map<Integer, Set<Integer>> subtaskMappings) {
        if (assignments.isEmpty() && subtaskMappings.isEmpty()) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }

        // A subtask without an entry in a non-empty mapping (presumably one that inherits no old
        // subtask) has no old subtasks; previously Map.get leaked a null into the descriptor here.
        Set<Integer> oldTaskInstances =
                subtaskMappings.isEmpty()
                        ? InflightDataRescalingDescriptor.NO_SUBTASKS
                        : subtaskMappings.getOrDefault(
                                instanceID.getSubtaskId(),
                                InflightDataRescalingDescriptor.NO_SUBTASKS);

        Map<Integer, RescaledChannelsMapping> rescaledChannelsMappings;
        if (assignments.isEmpty()) {
            rescaledChannelsMappings = InflightDataRescalingDescriptor.NO_MAPPINGS;
        } else {
            rescaledChannelsMappings =
                    assignments.entrySet().stream()
                            .collect(
                                    Collectors.toMap(
                                            Map.Entry::getKey,
                                            entry ->
                                                    new RescaledChannelsMapping(
                                                            mappingRetriever.apply(
                                                                    entry.getValue()))));
        }
        return new InflightDataRescalingDescriptor(oldTaskInstances, rescaledChannelsMappings);
    }

    /**
     * Looks up the state handles of one instance in one of the per-instance state maps.
     *
     * @param instanceID the operator instance to look up
     * @param statesByInstance any of this class's per-instance state maps
     * @return a {@link StateObjectCollection} wrapping the instance's handles, or an empty
     *     collection if the map has no entry for the instance
     */
    private <T extends StateObject> StateObjectCollection<T> getState(
            OperatorInstanceID instanceID, Map<OperatorInstanceID, List<T>> statesByInstance) {
        List<T> handles = statesByInstance.get(instanceID);
        return handles != null ? new StateObjectCollection<>(handles) : StateObjectCollection.empty();
    }
}

