package org.apache.flink.runtime.state.rescale;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.PartitionableListState;
import org.apache.flink.runtime.state.StateMigrationInfo;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/runtime/state/rescale/OperatorStateExtractor.class */
/**
 * Restores the operator (non-keyed) state of a {@link DefaultOperatorStateBackend} from the
 * migration files written during a runtime rescale.
 *
 * <p>Restoration proceeds in up to three phases:
 *
 * <ol>
 *   <li>Subtasks whose id is not smaller than the parallelism before rescaling (newly created
 *       subtasks) read every broadcast state from the file named by {@code toBroadcastPathName(0)}.
 *   <li>All registered operator list states are cleared and then re-read from the operator-state
 *       file of every old subtask whose key-group range before rescaling was non-empty.
 *   <li>States in {@link OperatorStateHandle.Mode#SPLIT_DISTRIBUTE} mode are cut down to the
 *       contiguous slice this subtask owns under the new parallelism.
 * </ol>
 *
 * <p>Subtasks whose migration info reports {@code isRemovedState()} restore nothing.
 */
public class OperatorStateExtractor extends StateExtractor<DefaultOperatorStateBackend> {
    private final DefaultOperatorStateBackend backend;
    /** Index of this subtask, taken from the migration info. */
    private final int subtaskId;
    /** Parallelism of the operator after rescaling. */
    private final int newParallelism;
    /** True if {@link #subtaskId} is not smaller than the parallelism before rescaling. */
    private final boolean isNewSubtask;
    /** True if the migration info flags this subtask's state as removed; nothing is restored. */
    private final boolean isDownscaledTask;
    private final StateMigrationInfo stateMigrationInfo;

    /**
     * Creates an extractor for a single subtask.
     *
     * @param defaultOperatorStateBackend backend whose registered states are refilled
     * @param stateMigrationInfo migration info describing this subtask and the old/new parallelism
     */
    public OperatorStateExtractor(
            DefaultOperatorStateBackend defaultOperatorStateBackend,
            StateMigrationInfo stateMigrationInfo) {
        super(defaultOperatorStateBackend);
        this.backend = defaultOperatorStateBackend;
        this.subtaskId = stateMigrationInfo.getSubtaskId();
        this.stateMigrationInfo = stateMigrationInfo;
        this.isNewSubtask =
                this.subtaskId >= stateMigrationInfo.getTaskMigrationInfo().getOldParallelism();
        this.isDownscaledTask =
                stateMigrationInfo
                        .getTaskMigrationInfo()
                        .getSubtaskInfoById(this.subtaskId)
                        .isRemovedState();
        this.newParallelism = stateMigrationInfo.getTaskMigrationInfo().getNewParallelism();
    }

    /**
     * Refills the backend's broadcast and operator states from migration files.
     *
     * <p>{@code str} and {@code i} are not used by this implementation.
     *
     * @param str unused here
     * @param str2 first path component; joined with {@code str3} via {@code joinPaths} to form the
     *     directory holding the migration files
     * @param str3 second path component (see {@code str2})
     * @param i unused here
     * @param runtimeRescaleStreamFactory factory used to wait for and open the migration files
     * @throws Exception if waiting for or reading a migration file fails, or if a file names a
     *     state that is not registered in the backend ({@link IllegalStateException})
     */
    @Override // org.apache.flink.runtime.state.rescale.StateExtractor
    public void extractStatesFromStorage(
            String str,
            String str2,
            String str3,
            int i,
            RuntimeRescaleStreamFactory runtimeRescaleStreamFactory)
            throws Exception {
        if (this.isDownscaledTask) {
            return;
        }
        if (this.isNewSubtask && !this.backend.getRegisteredBroadcastStates().isEmpty()) {
            readBroadcastStates(joinPaths(str2, str3), runtimeRescaleStreamFactory);
        }
        if (this.backend.getRegisteredOperatorStates().isEmpty()) {
            return;
        }
        Map<String, List<Tuple2<Integer, Integer>>> splitRanges =
                readOperatorStates(joinPaths(str2, str3), runtimeRescaleStreamFactory);

        // The offset returned for one SPLIT_DISTRIBUTE state seeds the next one, so the extra
        // elements of uneven splits rotate across subtasks.
        // NOTE(review): this relies on every subtask iterating the registered states in the same
        // order — confirm the backend's map gives a deterministic iteration order.
        int offset = 0;
        for (Map.Entry<String, PartitionableListState<?>> entry :
                this.backend.getRegisteredOperatorStates().entrySet()) {
            PartitionableListState<?> state = entry.getValue();
            if (state.getStateMetaInfo().getAssignmentMode()
                    == OperatorStateHandle.Mode.SPLIT_DISTRIBUTE) {
                offset =
                        repartitionSplitState(
                                state,
                                splitRanges.get(entry.getKey()),
                                this.newParallelism,
                                this.subtaskId,
                                offset);
            }
        }
    }

    /** Reads every broadcast state listed in the broadcast migration file into the backend. */
    private void readBroadcastStates(String baseDir, RuntimeRescaleStreamFactory streamFactory)
            throws Exception {
        String path = joinPaths(baseDir, MigrationKeyedStateInfo.toBroadcastPathName(0));
        waitStoragePathAvailability(streamFactory, path);
        // try-with-resources: the stream is closed even if a read or precondition check fails.
        try (InputStream in = streamFactory.createStateInputStream(path)) {
            DataInputViewStreamWrapper view = new DataInputViewStreamWrapper(in);
            int stateCount = view.readInt();
            for (int s = 0; s < stateCount; s++) {
                String name = view.readUTF();
                // Flink's Preconditions substitutes "%s" placeholders (not "{}").
                Preconditions.checkState(
                        this.backend.getRegisteredBroadcastStates().containsKey(name),
                        "broadcast state '%s' missing",
                        name);
                this.backend.getRegisteredBroadcastStates().get(name).read(in);
            }
        }
    }

    /**
     * Clears all registered operator states and re-reads them from the operator-state file of
     * every old subtask that owned a non-empty key-group range before rescaling. Files are read
     * in ascending subtask order.
     *
     * @return for each SPLIT_DISTRIBUTE state name, the {@code [start, end)} element range
     *     contributed by each file, in read order
     */
    private Map<String, List<Tuple2<Integer, Integer>>> readOperatorStates(
            String baseDir, RuntimeRescaleStreamFactory streamFactory) throws Exception {
        for (PartitionableListState<?> state :
                this.backend.getRegisteredOperatorStates().values()) {
            state.clear();
        }
        int[] sourceSubtasks =
                this.stateMigrationInfo.getTaskMigrationInfo().getSubtaskInfos().stream()
                        .filter(
                                info ->
                                        !info.getKgrBeforeRescale()
                                                .equals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE))
                        .mapToInt(info -> info.getIdx())
                        .sorted()
                        .toArray();

        Map<String, List<Tuple2<Integer, Integer>>> splitRanges = new HashMap<>();
        for (int sourceSubtask : sourceSubtasks) {
            String path =
                    joinPaths(baseDir, MigrationKeyedStateInfo.toOperatorStatePathName(sourceSubtask));
            waitStoragePathAvailability(streamFactory, path);
            // try-with-resources: the stream is closed even if a read or precondition check fails.
            try (InputStream in = streamFactory.createStateInputStream(path)) {
                DataInputViewStreamWrapper view = new DataInputViewStreamWrapper(in);
                int stateCount = view.readInt();
                for (int s = 0; s < stateCount; s++) {
                    String name = view.readUTF();
                    Preconditions.checkState(
                            this.backend.getRegisteredOperatorStates().containsKey(name),
                            "operator state '%s' missing",
                            name);
                    PartitionableListState<?> state =
                            this.backend.getRegisteredOperatorStates().get(name);
                    state.read(in);
                    if (state.getStateMetaInfo().getAssignmentMode()
                            == OperatorStateHandle.Mode.SPLIT_DISTRIBUTE) {
                        // size() is treated as a running total across files, so this file's
                        // elements span [previous end, current size).
                        // NOTE(review): assumes read() appends to the (previously cleared) list.
                        List<Tuple2<Integer, Integer>> ranges =
                                splitRanges.computeIfAbsent(name, k -> new ArrayList<>());
                        int start = ranges.isEmpty() ? 0 : ranges.get(ranges.size() - 1).f1;
                        ranges.add(Tuple2.of(start, state.size()));
                    }
                }
            }
        }
        return splitRanges;
    }

    /**
     * Shrinks a SPLIT_DISTRIBUTE list state to the slice owned by {@code targetSubtask}.
     *
     * <p>Elements are dealt out in order over the source ranges. Subtasks are visited starting at
     * {@code startOffset} (modulo {@code parallelism}). Each receives {@code size / parallelism}
     * elements, and the first {@code size % parallelism} visited receive one extra.
     *
     * @param state the state to repartition in place
     * @param sourceRanges consecutive {@code [start, end)} ranges covering the state's elements;
     *     may be {@code null} only when the state is empty (it is then never dereferenced)
     * @param parallelism number of subtasks the elements are split across
     * @param targetSubtask subtask whose slice is kept
     * @param startOffset subtask index at which distribution starts
     * @return the index of the first visited subtask that did not receive an extra element, to be
     *     used as {@code startOffset} for the next state
     */
    private int repartitionSplitState(
            PartitionableListState<?> state,
            List<Tuple2<Integer, Integer>> sourceRanges,
            int parallelism,
            int targetSubtask,
            int startOffset) {
        int total = state.size();
        int baseShare = total / parallelism;
        int extrasLeft = total % parallelism;
        int nextOffset = startOffset;

        int rangeIdx = 0;
        int consumedInRange = 0;
        boolean owned = false;
        int sliceStart = 0;
        int sliceEnd = 0;

        for (int k = 0; k < parallelism; k++) {
            int subtask = (k + startOffset) % parallelism;
            int quota = baseShare;
            if (extrasLeft > 0) {
                quota++;
                extrasLeft--;
            } else if (extrasLeft == 0) {
                // First subtask without an extra element: the next state starts here.
                nextOffset = subtask;
                extrasLeft--;
            }
            while (quota > 0) {
                Tuple2<Integer, Integer> range = sourceRanges.get(rangeIdx);
                int available = range.f1 - range.f0 - consumedInRange;
                int take = Math.min(available, quota);
                int from = range.f0 + consumedInRange;
                int to = from + take;
                if (take == available) {
                    // Range exhausted (possibly empty): move on to the next one.
                    rangeIdx++;
                    consumedInRange = 0;
                } else {
                    consumedInRange += take;
                }
                quota -= take;
                if (subtask == targetSubtask) {
                    sliceStart = owned ? Math.min(sliceStart, from) : from;
                    sliceEnd = owned ? Math.max(sliceEnd, to) : to;
                    owned = true;
                }
            }
        }

        if (owned) {
            state.repartition(sliceStart, sliceEnd);
        } else {
            state.clear();
        }
        return nextOffset;
    }
}
