package org.apache.flink.streaming.runtime.io.rescaling;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
import org.apache.flink.runtime.jobgraph.tasks.RescalableTask;
import org.apache.flink.runtime.rescale.RuntimeRescaleEvent;
import org.apache.flink.runtime.rescale.RuntimeRescaleException;
import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointableInputs;
import org.apache.flink.streaming.runtime.io.rescaling.RuntimeRescaleEventHandlerState;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/streaming/runtime/io/rescaling/BlockedRuntimeRescaleEventHandler.class */
/**
 * Runtime rescale event handler that aligns a {@link RuntimeRescaleEvent} across the unfinished
 * input channels of a task and drives a {@link RuntimeRescaleEventHandlerState} state machine with
 * every received event.
 *
 * <p>For each rescale event id the handler records which channels have delivered the event. The
 * number of channels to wait for is computed once, when a new id is first seen: the sum of
 * unfinished input channels over all inputs, skipping {@link IndexedInputGate}s that are already
 * finished. Events carrying an id lower than the one currently being aligned are ignored.
 *
 * <p>Only {@code taskName} is {@code volatile}; the remaining mutable state is not synchronized, so
 * presumably all events are processed on a single (task) thread — NOTE(review): confirm against
 * callers.
 */
public class BlockedRuntimeRescaleEventHandler extends RuntimeRescaleEventHandler {
    private static final Logger LOG =
            LoggerFactory.getLogger(BlockedRuntimeRescaleEventHandler.class);

    /** Name used as a prefix in log messages; can be replaced via {@link #setTaskName(String)}. */
    private volatile String taskName;

    /** Callback object handed to the state machine on every transition. */
    private final ControllerImpl context;

    /** Id of the rescale event currently being aligned, or {@code -1} before the first event. */
    private long currentRescaleEventId;

    /** Inputs whose unfinished channels determine {@link #targetChannelCount}. */
    private final CheckpointableInputs inputs;

    /** Channels that have already delivered the event with id {@link #currentRescaleEventId}. */
    private final Set<InputChannelInfo> alignedChannels;

    /** Number of channels that must deliver the current event before alignment is complete. */
    private int targetChannelCount;

    /** Current state of the alignment state machine; replaced on every transition. */
    private RuntimeRescaleEventHandlerState currentState;

    /**
     * {@link RuntimeRescaleEventHandlerState.Controller} passed to the state machine. It answers
     * alignment queries from the enclosing handler's fields and forwards the "aligned" callbacks
     * to the enclosing handler.
     */
    public final class ControllerImpl implements RuntimeRescaleEventHandlerState.Controller {
        private ControllerImpl() {}

        /** Returns whether every targeted channel has delivered the current rescale event. */
        @Override
        public boolean allRuntimeRescaleEventsReceived() {
            return alignedChannels.size() == targetChannelCount;
        }

        @Override
        public void migrateStatesIfNeedWhenAligned(RuntimeRescaleEvent runtimeRescaleEvent)
                throws IOException {
            BlockedRuntimeRescaleEventHandler.this.migrateStatesIfNeedWhenAligned(
                    runtimeRescaleEvent);
        }

        @Override
        public void triggerRuntimeRescaleEventWhenAligned(RuntimeRescaleEvent runtimeRescaleEvent)
                throws IOException {
            BlockedRuntimeRescaleEventHandler.this.triggerRuntimeRescaleEventWhenAligned(
                    runtimeRescaleEvent);
        }
    }

    /**
     * Creates a handler starting in the given state.
     *
     * @param str task name used in log messages; must not be null
     * @param rescalableTask task passed to the {@link RuntimeRescaleEventHandler} super class
     * @param runtimeRescaleEventHandlerState initial state of the state machine; must not be null
     * @param checkpointableInputs inputs used to compute how many channels must align. It is not
     *     checked here, but it is dereferenced when the first event arrives.
     */
    public BlockedRuntimeRescaleEventHandler(
            String str,
            RescalableTask rescalableTask,
            RuntimeRescaleEventHandlerState runtimeRescaleEventHandlerState,
            CheckpointableInputs checkpointableInputs) {
        super(rescalableTask);
        this.currentRescaleEventId = -1L;
        this.alignedChannels = new HashSet<>();
        this.taskName = Preconditions.checkNotNull(str);
        this.currentState = Preconditions.checkNotNull(runtimeRescaleEventHandlerState);
        this.context = new ControllerImpl();
        this.inputs = checkpointableInputs;
    }

    /**
     * Creates a handler whose state machine starts in {@link WaitingForFirstRuntimeRescaleEvent}.
     */
    public static BlockedRuntimeRescaleEventHandler create(
            String str, RescalableTask rescalableTask, CheckpointableInputs checkpointableInputs) {
        return new BlockedRuntimeRescaleEventHandler(
                str,
                rescalableTask,
                new WaitingForFirstRuntimeRescaleEvent(checkpointableInputs),
                checkpointableInputs);
    }

    /** Replaces the task name used in log messages. */
    public void setTaskName(String str) {
        this.taskName = str;
    }

    /**
     * Handles a rescale event received on {@code inputChannelInfo}.
     *
     * <p>Events older than the one currently being aligned are dropped. A newer id resets the
     * alignment bookkeeping before the event is recorded.
     *
     * <p>NOTE(review): only strictly older ids are filtered. After an abort, events with the
     * aborted id that arrive on other channels are still forwarded to the state machine. Confirm
     * that the state returned by {@code abort} tolerates this.
     */
    @Override
    public void processRuntimeRescaleEvent(
            RuntimeRescaleEvent runtimeRescaleEvent, InputChannelInfo inputChannelInfo)
            throws IOException {
        long id = runtimeRescaleEvent.getId();
        LOG.debug(
                "{}: Received runtime rescale event from channel {} @ {}.",
                taskName,
                inputChannelInfo,
                id);
        if (currentRescaleEventId > id) {
            return;
        }
        checkNewRuntimeRescaleEvent(runtimeRescaleEvent);
        Preconditions.checkState(currentRescaleEventId == id);
        markRuntimeRescaleEventAlignedAndTransformState(inputChannelInfo, runtimeRescaleEvent);
    }

    /**
     * Records {@code inputChannelInfo} as aligned and feeds the event to the state machine.
     *
     * <p>A {@link RuntimeRescaleException} thrown by the state machine aborts the current event.
     * Any other exception is rethrown through {@link ExceptionUtils#rethrowIOException}. Once
     * every targeted channel has been recorded, the alignment set is cleared.
     */
    protected void markRuntimeRescaleEventAlignedAndTransformState(
            InputChannelInfo inputChannelInfo, RuntimeRescaleEvent runtimeRescaleEvent)
            throws IOException {
        alignedChannels.add(inputChannelInfo);
        LOG.info(
                "Aligned {} of {} channels for task {}. Aligned channels are: {}",
                alignedChannels.size(),
                targetChannelCount,
                taskName,
                alignedChannels);
        try {
            currentState =
                    currentState.runtimeRescaleEventReceived(
                            context, inputChannelInfo, runtimeRescaleEvent);
        } catch (RuntimeRescaleException e) {
            abortInternal(currentRescaleEventId, e);
            // abortInternal has already reset alignedChannels and targetChannelCount to zero.
            // Falling through would satisfy 0 == 0 and falsely log that all channels are aligned.
            return;
        } catch (Exception e) {
            ExceptionUtils.rethrowIOException(e);
        }
        if (alignedChannels.size() == targetChannelCount) {
            alignedChannels.clear();
            LOG.debug(
                    "{}: All the channels are aligned for runtime rescale event {}.",
                    taskName,
                    currentRescaleEventId);
        }
    }

    /**
     * Aborts the rescale event {@code j}.
     *
     * <p>It advances the current id to at least {@code j}, resets the alignment bookkeeping,
     * moves the state machine to its aborted state and notifies the super class.
     */
    private void abortInternal(long j, RuntimeRescaleException runtimeRescaleException)
            throws IOException {
        LOG.debug(
                "{}: Aborting runtime rescale event {} after exception {}.",
                taskName,
                j,
                runtimeRescaleException);
        currentRescaleEventId = Math.max(j, currentRescaleEventId);
        alignedChannels.clear();
        targetChannelCount = 0;
        currentState = currentState.abort(j);
        notifyAbort(j, runtimeRescaleException);
    }

    /** Controller callback: forwards a state-migration request to the super class. */
    private void migrateStatesIfNeedWhenAligned(RuntimeRescaleEvent runtimeRescaleEvent)
            throws IOException {
        LOG.debug(
                "{}: Triggering states migration {} on the barrier announcement at {}.",
                taskName,
                runtimeRescaleEvent.getId(),
                runtimeRescaleEvent.getTimestamp());
        notifyMigrateStatesIfNeedWhenAligned(runtimeRescaleEvent);
    }

    /** Controller callback: forwards an "event aligned" trigger to the super class. */
    private void triggerRuntimeRescaleEventWhenAligned(RuntimeRescaleEvent runtimeRescaleEvent)
            throws IOException {
        LOG.debug(
                "{}: Triggering runtime rescale event {} on the barrier announcement at {}.",
                taskName,
                runtimeRescaleEvent.getId(),
                runtimeRescaleEvent.getTimestamp());
        notifyRuntimeRescaleEventWhenAligned(runtimeRescaleEvent);
    }

    /**
     * Starts a new alignment if the event's id differs from the current one.
     *
     * <p>It resets the aligned channels and recomputes the number of channels to wait for. That
     * number is the count of unfinished channels across all inputs, excluding inputs that are
     * finished {@link IndexedInputGate}s.
     */
    private void checkNewRuntimeRescaleEvent(RuntimeRescaleEvent runtimeRescaleEvent) {
        long id = runtimeRescaleEvent.getId();
        if (id == currentRescaleEventId) {
            return;
        }
        currentRescaleEventId = id;
        alignedChannels.clear();
        targetChannelCount =
                (int)
                        Arrays.stream(inputs.get())
                                .filter(
                                        input ->
                                                !(input instanceof IndexedInputGate
                                                        && ((IndexedInputGate) input)
                                                                .isFinished()))
                                .mapToLong(input -> input.getNumberOfUnfinishedInputChannels())
                                .sum();
    }
}
