package org.apache.flink.runtime.checkpoint;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.IntStream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.metrics.scope.ScopeFormat;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/apache/flink/runtime/checkpoint/TaskStateAssignment.class */
/**
 * Mutable working state for re-assigning the checkpointed state of one {@link ExecutionJobVertex}
 * to its (possibly rescaled) subtasks.
 *
 * <p>The instance holds the previous {@link OperatorState} per operator, one map per state kind
 * from {@link OperatorInstanceID} to the state handles assigned to that instance, and lazily
 * computed rescale mappings for the vertex' inputs and outputs. The per-kind maps are filled from
 * outside this class; {@link #getSubtaskState(OperatorInstanceID)} combines them into the final
 * {@link OperatorSubtaskState}.
 *
 * <p>NOTE(review): nothing here synchronizes access to the mutable maps or the lazily initialised
 * arrays, so this presumably expects single-threaded use — confirm against callers.
 */
public class TaskStateAssignment {
    private static final Logger LOG = LoggerFactory.getLogger(TaskStateAssignment.class);

    final ExecutionJobVertex executionJobVertex;
    final Map<OperatorID, OperatorState> oldState;
    /** True if at least one old operator state has collected a subtask state. */
    final boolean hasNonFinishedState;
    /** True if the old operator states are fully finished (checked to hold for all of them). */
    final boolean isFullyFinished;
    /** True if any old subtask state of the input operator carries input channel state. */
    final boolean hasInputState;
    /** True if any old subtask state of the output operator carries result subpartition state. */
    final boolean hasOutputState;
    final int newParallelism;
    /** Generated ID of the last entry of {@code executionJobVertex.getOperatorIDs()}. */
    final OperatorID inputOperatorID;
    /** Generated ID of the first entry of {@code executionJobVertex.getOperatorIDs()}. */
    final OperatorID outputOperatorID;
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState;
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedColdKeyedState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState;
    final Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates;
    final Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates;

    /** Rescale mappings per produced data set index, filled on demand. */
    private final Map<Integer, SubtasksRescaleMapping> outputSubtaskMappings = new HashMap<>();
    /** Rescale mappings per input index, filled on demand. */
    private final Map<Integer, SubtasksRescaleMapping> inputSubtaskMappings = new HashMap<>();

    /** Lazily resolved; entry {@code i} belongs to the i-th produced data set. */
    @Nullable
    private TaskStateAssignment[] downstreamAssignments;

    /** Lazily resolved; entry {@code i} belongs to the producer of the i-th input. */
    @Nullable
    private TaskStateAssignment[] upstreamAssignments;

    private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
    private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/flink/runtime/checkpoint/TaskStateAssignment$SubtasksRescaleMapping.class */
    /**
     * Pairs a {@link RescaleMappings} with a flag telling whether the mapping may contain
     * ambiguous subtasks.
     */
    public static class SubtasksRescaleMapping {
        private final RescaleMappings rescaleMappings;
        private final boolean mayHaveAmbiguousSubtasks;

        private SubtasksRescaleMapping(
                RescaleMappings rescaleMappings, boolean mayHaveAmbiguousSubtasks) {
            this.rescaleMappings = rescaleMappings;
            this.mayHaveAmbiguousSubtasks = mayHaveAmbiguousSubtasks;
        }

        public RescaleMappings getRescaleMappings() {
            return this.rescaleMappings;
        }

        public boolean isMayHaveAmbiguousSubtasks() {
            return this.mayHaveAmbiguousSubtasks;
        }
    }

    /**
     * Creates the assignment for one job vertex.
     *
     * @param executionJobVertex the vertex whose state is being assigned; provides the new
     *     parallelism and the operator IDs
     * @param oldState previous state per operator; must contain entries for the vertex' first and
     *     last operator ID
     * @param consumerAssignment lookup from an intermediate data set ID to the assignment used as
     *     this vertex' downstream for that data set; must not be null
     * @param vertexAssignments lookup from a job vertex to its assignment, used to resolve
     *     upstream assignments; must not be null
     * @throws IllegalStateException if only some of the old operator states are fully finished
     * @throws NullPointerException if one of the lookup maps is null
     */
    public TaskStateAssignment(
            ExecutionJobVertex executionJobVertex,
            Map<OperatorID, OperatorState> oldState,
            Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment,
            Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments) {
        this.executionJobVertex = executionJobVertex;
        this.oldState = oldState;
        this.hasNonFinishedState =
                oldState.values().stream()
                        .anyMatch(operatorState -> operatorState.getNumberCollectedStates() > 0);
        this.isFullyFinished = oldState.values().stream().anyMatch(OperatorState::isFullyFinished);
        if (this.isFullyFinished) {
            // A single finished operator implies that every operator of the vertex is finished.
            Preconditions.checkState(
                    oldState.values().stream().allMatch(OperatorState::isFullyFinished),
                    "JobVertex could not have mixed finished and unfinished operators");
        }

        this.newParallelism = executionJobVertex.getParallelism();
        this.consumerAssignment = Preconditions.checkNotNull(consumerAssignment);
        this.vertexAssignments = Preconditions.checkNotNull(vertexAssignments);

        // One entry per (operator, new subtask) pair is the most each map can receive.
        final int expectedNumberOfSubtasks = this.newParallelism * oldState.size();
        this.subManagedOperatorState = new HashMap<>(expectedNumberOfSubtasks);
        this.subRawOperatorState = new HashMap<>(expectedNumberOfSubtasks);
        this.inputChannelStates = new HashMap<>(expectedNumberOfSubtasks);
        this.resultSubpartitionStates = new HashMap<>(expectedNumberOfSubtasks);
        this.subManagedKeyedState = new HashMap<>(expectedNumberOfSubtasks);
        this.subManagedColdKeyedState = new HashMap<>(expectedNumberOfSubtasks);
        this.subRawKeyedState = new HashMap<>(expectedNumberOfSubtasks);

        // The first listed operator is treated as the output side, the last one as the input side.
        final List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
        this.outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID();
        this.inputOperatorID = operatorIDs.get(operatorIDs.size() - 1).getGeneratedOperatorID();

        this.hasInputState =
                oldState.get(this.inputOperatorID).getStates().stream()
                        .anyMatch(subtaskState -> !subtaskState.getInputChannelState().isEmpty());
        this.hasOutputState =
                oldState.get(this.outputOperatorID).getStates().stream()
                        .anyMatch(
                                subtaskState ->
                                        !subtaskState.getResultSubpartitionState().isEmpty());
    }

    /**
     * Returns the assignments consuming this vertex' produced data sets, in the order of {@code
     * executionJobVertex.getProducedDataSets()}. The array is resolved on first use and cached;
     * an entry is {@code null} when no consumer assignment is registered for the data set.
     */
    public TaskStateAssignment[] getDownstreamAssignments() {
        if (this.downstreamAssignments == null) {
            this.downstreamAssignments =
                    Arrays.stream(this.executionJobVertex.getProducedDataSets())
                            .map(result -> this.consumerAssignment.get(result.getId()))
                            .toArray(TaskStateAssignment[]::new);
        }
        return this.downstreamAssignments;
    }

    /** Returns the position of {@code assignment} in {@code assignments}, or -1 if absent. */
    private static int getAssignmentIndex(
            TaskStateAssignment[] assignments, TaskStateAssignment assignment) {
        return Arrays.asList(assignments).indexOf(assignment);
    }

    /**
     * Returns the assignments of the vertices producing this vertex' inputs, in the order of
     * {@code executionJobVertex.getInputs()}. The array is resolved on first use and cached; an
     * entry is {@code null} when no assignment is registered for the producer.
     */
    public TaskStateAssignment[] getUpstreamAssignments() {
        if (this.upstreamAssignments == null) {
            this.upstreamAssignments =
                    this.executionJobVertex.getInputs().stream()
                            .map(result -> this.vertexAssignments.get(result.getProducer()))
                            .toArray(TaskStateAssignment[]::new);
        }
        return this.upstreamAssignments;
    }

    /**
     * Builds the subtask state of one operator instance from the per-kind maps plus the in-flight
     * data rescaling descriptors for its inputs and outputs.
     *
     * @param instanceID the operator instance to build the state for
     * @return the combined state; state kinds without an entry for the instance are empty
     * @throws IllegalStateException if raw keyed state is present without managed keyed state
     */
    public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
        Preconditions.checkState(
                this.subManagedKeyedState.containsKey(instanceID)
                        || !this.subRawKeyedState.containsKey(instanceID),
                "If an operator has no managed key state, it should also not have a raw keyed state.");

        // Input side: each upstream vertex contributes the mapping of the output that feeds us.
        final InflightDataRescalingDescriptor inputDescriptor =
                createRescalingDescriptor(
                        instanceID,
                        this.inputOperatorID,
                        getUpstreamAssignments(),
                        (upstream, recompute) ->
                                upstream.getOutputMapping(
                                        getAssignmentIndex(
                                                upstream.getDownstreamAssignments(), this),
                                        recompute),
                        this.inputSubtaskMappings,
                        this::getInputMapping);

        // Output side: each downstream vertex contributes the mapping of the input we feed.
        final InflightDataRescalingDescriptor outputDescriptor =
                createRescalingDescriptor(
                        instanceID,
                        this.outputOperatorID,
                        getDownstreamAssignments(),
                        (downstream, recompute) ->
                                downstream.getInputMapping(
                                        getAssignmentIndex(
                                                downstream.getUpstreamAssignments(), this),
                                        recompute),
                        this.outputSubtaskMappings,
                        this::getOutputMapping);

        return OperatorSubtaskState.builder()
                .setManagedOperatorState(getState(instanceID, this.subManagedOperatorState))
                .setRawOperatorState(getState(instanceID, this.subRawOperatorState))
                .setManagedKeyedState(getState(instanceID, this.subManagedKeyedState))
                .setManagedColdKeyedState(getState(instanceID, this.subManagedColdKeyedState))
                .setRawKeyedState(getState(instanceID, this.subRawKeyedState))
                .setInputChannelState(getState(instanceID, this.inputChannelStates))
                .setResultSubpartitionState(getState(instanceID, this.resultSubpartitionStates))
                .setInputRescalingDescriptor(inputDescriptor)
                .setOutputRescalingDescriptor(outputDescriptor)
                .build();
    }

    /** Logs a freshly created gate/partition descriptor at debug level and returns it. */
    private InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor log(
            InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor
                    descriptor,
            int subtask,
            int partition) {
        LOG.debug(
                "created {} for task={} subtask={} partition={}",
                descriptor,
                this.executionJobVertex.getName(),
                subtask,
                partition);
        return descriptor;
    }

    /** Logs a freshly created rescaling descriptor at debug level and returns it. */
    private InflightDataRescalingDescriptor log(
            InflightDataRescalingDescriptor descriptor, int subtask) {
        LOG.debug(
                "created {} for task={} subtask={}",
                descriptor,
                this.executionJobVertex.getName(),
                subtask);
        return descriptor;
    }

    /**
     * Creates the rescaling descriptor for one side (input or output) of an operator instance.
     *
     * @param instanceID the instance the descriptor is created for
     * @param expectedOperatorID the operator that owns this side; any other operator gets {@code
     *     NO_RESCALE}
     * @param connectedAssignments the assignments on the other end of each gate/partition
     * @param peerMappingRetriever fetches the mapping held by a connected assignment; the flag
     *     states whether a missing mapping should be computed
     * @param ownMappings mappings already known for this side of this vertex
     * @param ownMappingRetriever computes this vertex' mapping for a gate/partition index
     * @return {@code NO_RESCALE} if nothing is known to be rescaled or all gate/partition
     *     descriptors are identity mappings, otherwise a descriptor wrapping them
     */
    private InflightDataRescalingDescriptor createRescalingDescriptor(
            OperatorInstanceID instanceID,
            OperatorID expectedOperatorID,
            TaskStateAssignment[] connectedAssignments,
            BiFunction<TaskStateAssignment, Boolean, SubtasksRescaleMapping> peerMappingRetriever,
            Map<Integer, SubtasksRescaleMapping> ownMappings,
            Function<Integer, SubtasksRescaleMapping> ownMappingRetriever) {
        if (!expectedOperatorID.equals(instanceID.getOperatorId())) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }

        // First pass only looks at mappings that already exist, without computing new ones.
        final SubtasksRescaleMapping[] existingPeerMappings =
                Arrays.stream(connectedAssignments)
                        .map(assignment -> peerMappingRetriever.apply(assignment, false))
                        .toArray(SubtasksRescaleMapping[]::new);

        // Neither this vertex nor any connected vertex has a mapping yet: nothing to rescale.
        if (ownMappings.isEmpty()
                && Arrays.stream(existingPeerMappings).allMatch(Objects::isNull)) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }

        final InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor[]
                gateOrPartitionDescriptors =
                        createGateOrPartitionRescalingDescriptors(
                                instanceID,
                                connectedAssignments,
                                assignment -> peerMappingRetriever.apply(assignment, true),
                                ownMappings,
                                ownMappingRetriever,
                                existingPeerMappings);

        final boolean allIdentity =
                Arrays.stream(gateOrPartitionDescriptors)
                        .allMatch(
                                InflightDataRescalingDescriptor
                                                .InflightDataGateOrPartitionRescalingDescriptor
                                        ::isIdentity);
        final InflightDataRescalingDescriptor descriptor =
                allIdentity
                        ? InflightDataRescalingDescriptor.NO_RESCALE
                        : new InflightDataRescalingDescriptor(gateOrPartitionDescriptors);
        return log(descriptor, instanceID.getSubtaskId());
    }

    /**
     * Creates one descriptor per gate/partition. Mappings that are not available yet are computed
     * on demand, first the connected assignment's and then this vertex' own.
     */
    private InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor[]
            createGateOrPartitionRescalingDescriptors(
                    OperatorInstanceID instanceID,
                    TaskStateAssignment[] connectedAssignments,
                    Function<TaskStateAssignment, SubtasksRescaleMapping> peerMappingRetriever,
                    Map<Integer, SubtasksRescaleMapping> ownMappings,
                    Function<Integer, SubtasksRescaleMapping> ownMappingRetriever,
                    SubtasksRescaleMapping[] existingPeerMappings) {
        return IntStream.range(0, existingPeerMappings.length)
                .mapToObj(
                        index -> {
                            final TaskStateAssignment peer = connectedAssignments[index];
                            SubtasksRescaleMapping peerMapping = existingPeerMappings[index];
                            if (peerMapping == null) {
                                peerMapping = peerMappingRetriever.apply(peer);
                            }
                            SubtasksRescaleMapping ownMapping = ownMappings.get(index);
                            if (ownMapping == null) {
                                ownMapping = ownMappingRetriever.apply(index);
                            }
                            return getInflightDataGateOrPartitionRescalingDescriptor(
                                    instanceID, index, peerMapping, ownMapping);
                        })
                .toArray(
                        InflightDataRescalingDescriptor
                                        .InflightDataGateOrPartitionRescalingDescriptor[]
                                ::new);
    }

    /**
     * Creates the descriptor of a single gate/partition.
     *
     * <p>The descriptor is marked as identity when both mappings are identity mappings or when
     * this vertex' mapping yields no old subtask for the instance; otherwise it is a rescaling.
     */
    private InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor
            getInflightDataGateOrPartitionRescalingDescriptor(
                    OperatorInstanceID instanceID,
                    int gateOrPartitionIndex,
                    SubtasksRescaleMapping peerMapping,
                    SubtasksRescaleMapping ownMapping) {
        final int[] oldSubtaskIndexes =
                ownMapping.rescaleMappings.getMappedIndexes(instanceID.getSubtaskId());
        final boolean isIdentity =
                (ownMapping.rescaleMappings.isIdentity()
                                && peerMapping.getRescaleMappings().isIdentity())
                        || oldSubtaskIndexes.length == 0;
        final InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor
                        .MappingType
                mappingType =
                        isIdentity
                                ? InflightDataRescalingDescriptor
                                        .InflightDataGateOrPartitionRescalingDescriptor.MappingType
                                        .IDENTITY
                                : InflightDataRescalingDescriptor
                                        .InflightDataGateOrPartitionRescalingDescriptor.MappingType
                                        .RESCALING;
        return log(
                new InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor(
                        oldSubtaskIndexes,
                        peerMapping.getRescaleMappings(),
                        // Ambiguous targets are only reported when the mapper flagged ambiguity.
                        ownMapping.mayHaveAmbiguousSubtasks
                                ? ownMapping.rescaleMappings.getAmbiguousTargets()
                                : Collections.emptySet(),
                        mappingType),
                instanceID.getSubtaskId(),
                gateOrPartitionIndex);
    }

    /** Wraps the handles stored for {@code instanceID}; returns an empty collection if none. */
    private <T extends StateObject> StateObjectCollection<T> getState(
            OperatorInstanceID instanceID, Map<OperatorInstanceID, List<T>> subtaskStates) {
        final List<T> handles = subtaskStates.get(instanceID);
        return handles != null ? new StateObjectCollection<>(handles) : StateObjectCollection.empty();
    }

    /**
     * Returns the cached output mapping for a partition; if it is missing it is computed when
     * {@code computeIfMissing} is set, otherwise {@code null} is returned.
     */
    private SubtasksRescaleMapping getOutputMapping(int partitionIndex, boolean computeIfMissing) {
        final SubtasksRescaleMapping cached = this.outputSubtaskMappings.get(partitionIndex);
        return (computeIfMissing && cached == null) ? getOutputMapping(partitionIndex) : cached;
    }

    /**
     * Returns the cached input mapping for a gate; if it is missing it is computed when {@code
     * computeIfMissing} is set, otherwise {@code null} is returned.
     */
    private SubtasksRescaleMapping getInputMapping(int gateIndex, boolean computeIfMissing) {
        final SubtasksRescaleMapping cached = this.inputSubtaskMappings.get(gateIndex);
        return (computeIfMissing && cached == null) ? getInputMapping(gateIndex) : cached;
    }

    /**
     * Computes the new-to-old subtask mapping for the given produced data set, using the upstream
     * subtask state mapper of the matching input edge of the downstream vertex, and stores it.
     *
     * <p>The mapping is always recomputed; an already stored mapping is only used to verify
     * compatibility and to carry over its ambiguity flag.
     *
     * @param partitionIndex index into {@code executionJobVertex.getProducedDataSets()}
     * @return the stored mapping
     * @throws NullPointerException if the edge has no upstream subtask state mapper
     * @throws IllegalStateException if a different mapping was stored before
     */
    public SubtasksRescaleMapping getOutputMapping(int partitionIndex) {
        final TaskStateAssignment downstream = getDownstreamAssignments()[partitionIndex];
        // Position of our produced data set among the inputs of the downstream vertex.
        final int downstreamInputIndex =
                downstream
                        .executionJobVertex
                        .getInputs()
                        .indexOf(this.executionJobVertex.getProducedDataSets()[partitionIndex]);
        final SubtaskStateMapper mapper =
                Preconditions.checkNotNull(
                        downstream
                                .executionJobVertex
                                .getJobVertex()
                                .getInputs()
                                .get(downstreamInputIndex)
                                .getUpstreamSubtaskStateMapper(),
                        "No channel rescaler found during rescaling of channel state");
        final RescaleMappings mapping =
                mapper.getNewToOldSubtasksMapping(
                        this.oldState.get(this.outputOperatorID).getParallelism(),
                        this.newParallelism);
        return this.outputSubtaskMappings.compute(
                partitionIndex,
                (key, existing) -> checkSubtaskMapping(existing, mapping, mapper.isAmbiguous()));
    }

    /**
     * Computes the new-to-old subtask mapping for the given input, using the downstream subtask
     * state mapper of that input edge, and stores it.
     *
     * <p>The mapping is always recomputed; an already stored mapping is only used to verify
     * compatibility and to carry over its ambiguity flag.
     *
     * @param gateIndex index into the inputs of the job vertex
     * @return the stored mapping
     * @throws NullPointerException if the edge has no downstream subtask state mapper
     * @throws IllegalStateException if a different mapping was stored before
     */
    public SubtasksRescaleMapping getInputMapping(int gateIndex) {
        final SubtaskStateMapper mapper =
                Preconditions.checkNotNull(
                        this.executionJobVertex
                                .getJobVertex()
                                .getInputs()
                                .get(gateIndex)
                                .getDownstreamSubtaskStateMapper(),
                        "No channel rescaler found during rescaling of channel state");
        final RescaleMappings mapping =
                mapper.getNewToOldSubtasksMapping(
                        this.oldState.get(this.inputOperatorID).getParallelism(),
                        this.newParallelism);
        return this.inputSubtaskMappings.compute(
                gateIndex,
                (key, existing) -> checkSubtaskMapping(existing, mapping, mapper.isAmbiguous()));
    }

    @Override
    public String toString() {
        return "TaskStateAssignment for " + this.executionJobVertex.getName();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Merges a newly computed mapping with the one stored before.
     *
     * @param oldMapping the previously stored mapping, or {@code null} if there is none
     * @param mapping the newly computed mapping
     * @param mayHaveAmbiguousSubtasks ambiguity flag of the newly computed mapping
     * @return a mapping wrapping {@code mapping}; ambiguous if either the old or the new mapping
     *     is flagged as ambiguous
     * @throws IllegalStateException if the stored mapping differs from the new one
     */
    @Nonnull
    public static SubtasksRescaleMapping checkSubtaskMapping(
            @Nullable SubtasksRescaleMapping oldMapping,
            RescaleMappings mapping,
            boolean mayHaveAmbiguousSubtasks) {
        if (oldMapping == null) {
            return new SubtasksRescaleMapping(mapping, mayHaveAmbiguousSubtasks);
        }
        if (!oldMapping.rescaleMappings.equals(mapping)) {
            // Report the wrapped RescaleMappings on both sides: the wrapper has no toString(),
            // so printing it would only show an identity hash.
            // NOTE(review): SCOPE_SEPARATOR serves as the closing punctuation here; this looks
            // like a decompiler constant substitution for a plain "." — confirm before replacing.
            throw new IllegalStateException(
                    "Incompatible subtask mappings: are multiple operators ingesting/producing "
                            + "intermediate results with varying degrees of parallelism? Found "
                            + oldMapping.rescaleMappings
                            + " and "
                            + mapping
                            + ScopeFormat.SCOPE_SEPARATOR);
        }
        return new SubtasksRescaleMapping(
                mapping, oldMapping.mayHaveAmbiguousSubtasks || mayHaveAmbiguousSubtasks);
    }
}
