package org.apache.flink.runtime.state.heap;

import java.math.BigDecimal;
import java.time.Duration;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.heap.HeapStatusMonitor;
import org.apache.flink.runtime.state.heap.SpillableStateTable;
import org.apache.flink.shaded.guava30.com.google.common.base.Objects;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/runtime/state/heap/SpillAndLoadManagerImpl.class */
public class SpillAndLoadManagerImpl implements SpillAndLoadManager {
    private static final Logger LOG = LoggerFactory.getLogger((Class<?>) SpillAndLoadManagerImpl.class);

    // Weights used when ranking state maps for a SPILL: a larger estimated size raises the
    // rank, a higher request count lowers it (same literals as computeWeight's SPILL case).
    private static final double WEIGHT_SPILL_RETAINED_SIZE = 0.7d;
    private static final double WEIGHT_SPILL_REQUEST_RATE = -0.3d;
    // NOTE(review): looks like the constant-folded value of 0.7 + (-0.3) in double arithmetic — confirm.
    private static final double WEIGHT_SPILL_SUM = 0.39999999999999997d;

    // Weights used when ranking state maps for a LOAD: the mirror image of the spill weights
    // (same literals as computeWeight's LOAD case).
    private static final double WEIGHT_LOAD_RETAINED_SIZE = -0.3d;
    private static final double WEIGHT_LOAD_REQUEST_RATE = 0.7d;
    private static final double WEIGHT_LOAD_SUM = 0.39999999999999997d;

    // Source of the (name, state table) pairs inspected when spilling or loading.
    private final StateTableContainer stateTableContainer;
    // Supplies the max heap size and the periodic monitor results that drive decisions.
    private final HeapStatusMonitor heapStatusMonitor;
    // Used by doSpill to cancel all checkpoints when cancelCheckpoint is set.
    private final CheckpointManager checkpointManager;
    // Read from SpillableOptions.CANCEL_CHECKPOINT.
    private final boolean cancelCheckpoint;
    // GC time (milliseconds) above which decideAction chooses SPILL.
    private final long gcTimeThreshold;
    // Ratio attached to every SPILL decision; doSpill multiplies the total estimated size by it.
    private final float spillSizeRatio;
    // Load start/end ratios; the constructor falls back to the option defaults when start >= end.
    private final float loadStartRatio;
    private final float loadEndRatio;
    // Minimum time (milliseconds) between two spill/load executions.
    private final long triggerInterval;
    // Minimum time (milliseconds) between two polls of the heap status monitor.
    private final long resourceCheckInterval;
    // Value of heapStatusMonitor.getMaxMemory() captured at construction time.
    private final long maxMemory;
    // Absolute thresholds derived in the constructor as maxMemory * the corresponding ratio.
    private final long spillStartSize;
    private final long loadStartSize;
    private final long loadEndSize;
    // Wall-clock timestamps (System.currentTimeMillis) of the last poll and the last spill/load.
    private long lastResourceCheckTime;
    private long lastTriggerTime;
    // Most recent monitor result acted upon; null until the first poll.
    private HeapStatusMonitor.MonitorResult lastMonitorResult;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/flink/runtime/state/heap/SpillAndLoadManagerImpl$Action.class */
    /** The kind of work {@code decideAction} can request from the manager. */
    public enum Action {
        /** Do nothing; {@code checkResource} returns without touching any state. */
        NONE,
        /** Hand state maps to {@code doSpill}. */
        SPILL,
        /** Hand state maps to {@code doLoad}. */
        LOAD
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/flink/runtime/state/heap/SpillAndLoadManagerImpl$ActionResult.class */
    /**
     * Outcome of a spill/load decision: which {@link Action} to take and the ratio that
     * scales how much estimated state memory the action should cover.
     */
    public static class ActionResult {
        /** The action to perform. */
        Action action;
        /** Ratio applied to the total estimated size of the candidate state maps. */
        float spillOrLoadRatio;

        ActionResult(Action action, float f) {
            this.action = action;
            this.spillOrLoadRatio = f;
        }

        /**
         * Two results are equal when they carry the same action and the same ratio.
         *
         * <p>The ratio is compared with {@link Float#compare(float, float)} rather than
         * {@code ==} so that equality agrees with {@link #hashCode()} (which hashes the boxed
         * {@code Float}): {@code NaN} equals itself and {@code 0.0f} differs from {@code -0.0f}.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            ActionResult other = (ActionResult) obj;
            return this.action == other.action
                    && Float.compare(this.spillOrLoadRatio, other.spillOrLoadRatio) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(this.action, Float.valueOf(this.spillOrLoadRatio));
        }

        @Override
        public String toString() {
            return "ActionResult{action=" + this.action + ", spillOrLoadRatio=" + this.spillOrLoadRatio + '}';
        }

        /** Creates a result that requests no work (ratio 0). */
        static ActionResult ofNone() {
            return new ActionResult(Action.NONE, 0.0f);
        }

        /** Creates a spill request with the given ratio. */
        static ActionResult ofSpill(float f) {
            return new ActionResult(Action.SPILL, f);
        }

        /** Creates a load request with the given ratio. */
        static ActionResult ofLoad(float f) {
            return new ActionResult(Action.LOAD, f);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/flink/runtime/state/heap/SpillAndLoadManagerImpl$StateTableContainer.class */
    /**
     * Iterable view over (name, state table) pairs; the manager iterates it to collect
     * spill/load candidates and to average the spilled ratio.
     */
    public interface StateTableContainer extends Iterable<Tuple2<String, SpillableStateTable>> {
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/flink/runtime/state/heap/SpillAndLoadManagerImpl$StateTableContainerImpl.class */
    /**
     * {@link StateTableContainer} backed by a map of registered key/value state tables.
     * Every map value is cast to {@link SpillableStateTable} while iterating.
     */
    public static class StateTableContainerImpl<K> implements StateTableContainer {
        private final Map<String, StateTable<K, ?, ?>> registeredKVStates;

        public StateTableContainerImpl(Map<String, StateTable<K, ?, ?>> map) {
            this.registeredKVStates = map;
        }

        /** Returns an iterator that adapts each map entry to a (name, table) tuple. */
        @Override // java.lang.Iterable
        public Iterator<Tuple2<String, SpillableStateTable>> iterator() {
            return new Iterator<Tuple2<String, SpillableStateTable>>() {
                // Entry iterator of the backing map, obtained when this adapter is created.
                private final Iterator<Map.Entry<String, StateTable<K, ?, ?>>> entries =
                        StateTableContainerImpl.this.registeredKVStates.entrySet().iterator();

                @Override // java.util.Iterator
                public boolean hasNext() {
                    return entries.hasNext();
                }

                @Override // java.util.Iterator
                public Tuple2<String, SpillableStateTable> next() {
                    Map.Entry<String, StateTable<K, ?, ?>> entry = entries.next();
                    String stateName = entry.getKey();
                    SpillableStateTable table = (SpillableStateTable) entry.getValue();
                    return Tuple2.of(stateName, table);
                }
            };
        }
    }

    /**
     * Creates a manager that reads its thresholds and intervals from {@code configuration}.
     *
     * <p>If the configured load start ratio is not strictly below the load end ratio, a
     * warning is logged and the default values of both options are used instead. The
     * absolute spill/load sizes are derived from {@code heapStatusMonitor.getMaxMemory()}.
     *
     * @param stateTableContainer source of the state tables to manage; must not be null
     * @param heapStatusMonitor provider of heap statistics; must not be null
     * @param checkpointManager used to cancel checkpoints before spilling; must not be null
     * @param configuration configuration holding the {@code SpillableOptions} values
     */
    public SpillAndLoadManagerImpl(StateTableContainer stateTableContainer, HeapStatusMonitor heapStatusMonitor, CheckpointManager<KeyedStateHandle> checkpointManager, Configuration configuration) {
        this.stateTableContainer = Preconditions.checkNotNull(stateTableContainer);
        this.heapStatusMonitor = Preconditions.checkNotNull(heapStatusMonitor);
        this.checkpointManager = Preconditions.checkNotNull(checkpointManager);
        this.cancelCheckpoint = configuration.get(SpillableOptions.CANCEL_CHECKPOINT);
        this.gcTimeThreshold = configuration.get(SpillableOptions.GC_TIME_THRESHOLD).toMillis();

        float startRatio = configuration.get(SpillableOptions.LOAD_START_RATIO);
        float endRatio = configuration.get(SpillableOptions.LOAD_END_RATIO);
        if (startRatio >= endRatio) {
            LOG.warn("Load start ratio {} >= end ratio {} even with adjustment, will use default (startRatio={}, endRatio={}) instead", startRatio, endRatio, SpillableOptions.LOAD_START_RATIO.defaultValue(), SpillableOptions.LOAD_END_RATIO.defaultValue());
            startRatio = SpillableOptions.LOAD_START_RATIO.defaultValue();
            endRatio = SpillableOptions.LOAD_END_RATIO.defaultValue();
        }
        this.loadStartRatio = startRatio;
        this.loadEndRatio = endRatio;

        float spillStartRatio = configuration.get(SpillableOptions.SPILL_START_RATIO);
        this.spillSizeRatio = configuration.get(SpillableOptions.SPILL_SIZE_RATIO);
        this.triggerInterval = configuration.get(SpillableOptions.TRIGGER_INTERVAL).toMillis();
        this.resourceCheckInterval = configuration.get(SpillableOptions.RESOURCE_CHECK_INTERVAL).toMillis();

        this.maxMemory = heapStatusMonitor.getMaxMemory();
        // long * float is evaluated in float; the result must be narrowed explicitly to long.
        this.spillStartSize = (long) (this.maxMemory * spillStartRatio);
        this.loadStartSize = (long) (this.maxMemory * this.loadStartRatio);
        this.loadEndSize = (long) (this.maxMemory * this.loadEndRatio);

        this.lastResourceCheckTime = System.currentTimeMillis();
        this.lastTriggerTime = System.currentTimeMillis();
    }

    /**
     * Polls the heap status monitor (at most once per resource-check interval) and, when a
     * new monitor result asks for it, performs a spill or a load.
     *
     * <p>A monitor result whose id equals the previously processed one is ignored, and no
     * spill/load is executed while less than the trigger interval has passed between the
     * last execution and the result's timestamp.
     */
    @Override // org.apache.flink.runtime.state.heap.SpillAndLoadManager
    public void checkResource() {
        final long now = System.currentTimeMillis();
        if (now - lastResourceCheckTime < resourceCheckInterval) {
            return;
        }
        lastResourceCheckTime = now;

        final HeapStatusMonitor.MonitorResult latest = heapStatusMonitor.resetAndGetMonitorResult();
        LOG.debug("Update monitor result {}", latest);

        // Nothing to do if this exact result has already been handled.
        if (lastMonitorResult != null && lastMonitorResult.getId() == latest.getId()) {
            return;
        }
        lastMonitorResult = latest;

        final ActionResult decision = decideAction(latest);
        LOG.debug("Decide action {}", decision);
        if (decision.action == Action.NONE) {
            return;
        }

        if (latest.getTimestamp() - lastTriggerTime < triggerInterval) {
            LOG.debug("Too frequent to spill/load, last time is {}", Long.valueOf(lastTriggerTime));
            return;
        }

        if (decision.action == Action.SPILL) {
            doSpill(decision);
        } else {
            doLoad(decision);
        }
        lastTriggerTime = System.currentTimeMillis();
    }

    /**
     * Returns the arithmetic mean of {@code getSpilledRatio()} over all state tables in the
     * container.
     *
     * @return the average spilled ratio, or {@code 0} when the container holds no state
     *     table (instead of the {@code NaN} that a {@code 0 / 0} float division would yield)
     */
    @Override // org.apache.flink.runtime.state.heap.SpillAndLoadManager
    public float getSpilledRatio() {
        int tableCount = 0;
        float ratioSum = 0.0f;
        for (Tuple2<String, SpillableStateTable> entry : this.stateTableContainer) {
            ratioSum += entry.f1.getSpilledRatio();
            tableCount++;
        }
        if (tableCount == 0) {
            // No tables registered: nothing is spilled, and dividing would produce NaN.
            return 0.0f;
        }
        return ratioSum / tableCount;
    }

    /**
     * Chooses the action for a monitor result.
     *
     * <ul>
     *   <li>SPILL (with the configured spill size ratio) when used memory exceeds the spill
     *       start size or the GC time exceeds the GC time threshold;
     *   <li>otherwise LOAD when used memory is below the load start size, with ratio
     *       {@code (loadEndSize - usedMemory) / usedMemory};
     *   <li>otherwise NONE.
     * </ul>
     *
     * @param monitorResult the heap statistics to evaluate
     * @return the decided action and its ratio
     */
    @VisibleForTesting
    ActionResult decideAction(HeapStatusMonitor.MonitorResult monitorResult) {
        final long gcTime = monitorResult.getGarbageCollectionTime();
        final long usedMemory = monitorResult.getTotalUsedMemory();

        if (usedMemory > spillStartSize || gcTime > gcTimeThreshold) {
            return ActionResult.ofSpill(spillSizeRatio);
        }
        if (usedMemory >= loadStartSize) {
            return ActionResult.ofNone();
        }
        final float headroom = (float) (loadEndSize - usedMemory);
        return ActionResult.ofLoad(headroom / ((float) usedMemory));
    }

    /**
     * Test hook that runs a spill with the given ratio, bypassing {@code decideAction} and
     * the interval checks of {@code checkResource}.
     *
     * @param f spill ratio; must not be null
     */
    @VisibleForTesting
    public void externalDoSpill(Float f) {
        final float ratio = f.floatValue();
        final ActionResult spillRequest = ActionResult.ofSpill(ratio);
        doSpill(spillRequest);
    }

    /**
     * Spills state maps until roughly {@code totalEstimatedSize * spillOrLoadRatio} bytes of
     * estimated memory have been released.
     *
     * <p>Candidates are non-empty state maps that are on heap or complex. They are processed
     * in the order produced by {@code sortStateMapMeta}. For a complex state map the amount
     * credited is the drop in its heap size multiplied by the table's estimated state size;
     * for any other map it is the map's estimated memory size. When configured, all
     * checkpoints are cancelled before the first spill.
     *
     * @param actionResult the spill decision carrying the ratio to apply
     */
    private void doSpill(ActionResult actionResult) {
        List<SpillableStateTable.StateMapMeta> candidates = getStateMapMetas(
                meta -> (meta.isOnHeap() || meta.isComplex()) && meta.getSize() > 0);
        if (candidates.isEmpty()) {
            LOG.debug("There is no StateMap to spill.");
            return;
        }
        sortStateMapMeta(actionResult.action, candidates);

        long totalEstimatedSize = 0L;
        for (SpillableStateTable.StateMapMeta meta : candidates) {
            totalEstimatedSize += meta.getEstimatedMemorySize();
        }
        // long * float is evaluated in float; narrow the product explicitly to long.
        long remaining = (long) (totalEstimatedSize * actionResult.spillOrLoadRatio);
        if (remaining == 0) {
            return;
        }

        if (this.cancelCheckpoint) {
            this.checkpointManager.cancelAllCheckpoints();
        }

        for (SpillableStateTable.StateMapMeta meta : candidates) {
            final int keyGroup = meta.getKeyGroupIndex();
            long heapSizeBefore = 0;
            if (meta.isComplex()) {
                Preconditions.checkState(meta.getStateTable() instanceof ComplexSpillableStateTableImpl, "Only ComplexSpillableStateTableImpl can be here");
                heapSizeBefore = ((ComplexStateMap) meta.getStateTable().getMapForKeyGroup(keyGroup)).getHeapSize();
            }

            meta.getStateTable().spillState(keyGroup);
            LOG.debug("Spill state in key group {} successfully", Integer.valueOf(keyGroup));

            if (meta.isComplex()) {
                // Credit only what actually left the heap for a complex state map.
                long heapSizeAfter = ((ComplexStateMap) meta.getStateTable().getMapForKeyGroup(keyGroup)).getHeapSize();
                remaining -= (heapSizeBefore - heapSizeAfter) * meta.getStateTable().getStateEstimatedSize(true);
            } else {
                remaining -= meta.getEstimatedMemorySize();
            }
            if (remaining <= 0) {
                return;
            }
        }
    }

    /**
     * Test hook that runs a load with the given ratio, bypassing {@code decideAction} and
     * the interval checks of {@code checkResource}.
     *
     * @param f load ratio; must not be null
     */
    @VisibleForTesting
    public void externalDoLoad(Float f) {
        final float ratio = f.floatValue();
        final ActionResult loadRequest = ActionResult.ofLoad(ratio);
        doLoad(loadRequest);
    }

    /**
     * Loads state maps back while they fit into a budget of
     * {@code totalEstimatedSize * spillOrLoadRatio} bytes of estimated memory.
     *
     * <p>Candidates are non-empty state maps that are off heap or complex. They are processed
     * in the order produced by {@code sortStateMapMeta}; processing stops at the first
     * candidate whose estimated size exceeds the remaining budget.
     *
     * @param actionResult the load decision carrying the ratio to apply
     */
    private void doLoad(ActionResult actionResult) {
        List<SpillableStateTable.StateMapMeta> candidates = getStateMapMetas(
                meta -> (!meta.isOnHeap() || meta.isComplex()) && meta.getSize() > 0);
        if (candidates.isEmpty()) {
            LOG.debug("There is no StateMap to load.");
            return;
        }
        sortStateMapMeta(actionResult.action, candidates);

        long totalEstimatedSize = 0L;
        for (SpillableStateTable.StateMapMeta meta : candidates) {
            totalEstimatedSize += meta.getEstimatedMemorySize();
        }
        // long * float is evaluated in float; narrow the product explicitly to long.
        long budget = (long) (totalEstimatedSize * actionResult.spillOrLoadRatio);
        if (budget == 0) {
            return;
        }

        for (SpillableStateTable.StateMapMeta meta : candidates) {
            budget -= meta.getEstimatedMemorySize();
            if (budget < 0) {
                // This candidate no longer fits; stop instead of overshooting the budget.
                return;
            }
            meta.getStateTable().loadState(meta.getKeyGroupIndex());
            LOG.debug("Load state in key group {} successfully", Integer.valueOf(meta.getKeyGroupIndex()));
        }
    }

    /**
     * Collects the metadata of every state map accepted by {@code function} and stamps each
     * collected entry with an estimated memory size.
     *
     * <p>The estimate is {@code meta.getSize() * table.getStateEstimatedSize(true)}; the
     * per-table estimate is requested only for tables that contributed at least one entry.
     *
     * @param function filter deciding whether a state map is a candidate
     * @return the accepted metadata, in container iteration order
     * @throws IllegalStateException if a table reports a negative estimated state size
     */
    private List<SpillableStateTable.StateMapMeta> getStateMapMetas(Function<SpillableStateTable.StateMapMeta, Boolean> function) {
        List<SpillableStateTable.StateMapMeta> selected = new ArrayList<>();
        for (Tuple2<String, SpillableStateTable> entry : this.stateTableContainer) {
            final int firstNewIndex = selected.size();
            SpillableStateTable table = entry.f1;
            Iterator<SpillableStateTable.StateMapMeta> metas = table.stateMapIterator();
            while (metas.hasNext()) {
                SpillableStateTable.StateMapMeta meta = metas.next();
                if (function.apply(meta)) {
                    selected.add(meta);
                }
            }
            if (firstNewIndex == selected.size()) {
                // This table contributed nothing, so skip the size estimation.
                continue;
            }
            long stateEstimatedSize = table.getStateEstimatedSize(true);
            // Flink's Preconditions substitutes %s placeholders (not SLF4J-style {}).
            Preconditions.checkState(stateEstimatedSize >= 0, "state estimated size should be positive but is %s", stateEstimatedSize);
            for (int i = firstNewIndex; i < selected.size(); i++) {
                SpillableStateTable.StateMapMeta meta = selected.get(i);
                meta.setEstimatedMemorySize(meta.getSize() * stateEstimatedSize);
            }
        }
        return selected;
    }

    /**
     * Sorts {@code list} in place by descending weight for the given action, so that the
     * most suitable state maps come first.
     *
     * <p>Weights come from {@code computeWeight}, which min-max normalises the estimated
     * memory size and the request count over the list. Each weight is computed once per
     * element and cached by identity. Elements with equal weight compare as equal, which
     * keeps the comparator consistent with the {@link java.util.Comparator} contract and
     * preserves their relative order (the sort is stable).
     *
     * @param action the action the ranking is for (SPILL or LOAD)
     * @param list the state map metadata to sort; modified in place
     */
    private void sortStateMapMeta(Action action, List<SpillableStateTable.StateMapMeta> list) {
        if (list.isEmpty()) {
            return;
        }
        long maxSize = 0;
        long minSize = Long.MAX_VALUE;
        long maxRequests = 0;
        long minRequests = Long.MAX_VALUE;
        for (SpillableStateTable.StateMapMeta meta : list) {
            long size = meta.getEstimatedMemorySize();
            maxSize = Math.max(maxSize, size);
            minSize = Math.min(minSize, size);
            long requests = meta.getNumRequests();
            maxRequests = Math.max(maxRequests, requests);
            minRequests = Math.min(minRequests, requests);
        }
        // Effectively-final copies for use inside the lambdas below.
        final long sizeFloor = minSize;
        final long requestFloor = minRequests;
        final long sizeRange = maxSize - minSize;
        final long requestRange = maxRequests - minRequests;

        final Map<SpillableStateTable.StateMapMeta, Double> weights = new IdentityHashMap<>();
        final Function<SpillableStateTable.StateMapMeta, Double> weigher =
                meta -> computeWeight(meta, action, sizeFloor, requestFloor, sizeRange, requestRange);

        list.sort((left, right) -> {
            if (left == right) {
                return 0;
            }
            if (left == null) {
                return -1;
            }
            if (right == null) {
                return 1;
            }
            double leftWeight = weights.computeIfAbsent(left, weigher);
            double rightWeight = weights.computeIfAbsent(right, weigher);
            // Arguments swapped on purpose: higher weight sorts first.
            return Double.compare(rightWeight, leftWeight);
        });
    }

    /**
     * Computes the ranking weight of one state map for the given action.
     *
     * <p>The estimated memory size and the request count are min-max normalised to
     * {@code [0, 1]} (0 when the corresponding range is empty) and combined as
     * {@code (sizeWeight * size + requestWeight * requests) / weightSum}. For SPILL a larger
     * size raises the weight and more requests lower it; for LOAD it is the reverse.
     *
     * @param stateMapMeta the state map being weighed
     * @param action SPILL or LOAD
     * @param minSize smallest estimated memory size among the candidates
     * @param minRequests smallest request count among the candidates
     * @param sizeRange max minus min estimated memory size among the candidates
     * @param requestRange max minus min request count among the candidates
     * @return the weight; higher means "handle earlier"
     * @throws RuntimeException if {@code action} is neither SPILL nor LOAD
     */
    private double computeWeight(SpillableStateTable.StateMapMeta stateMapMeta, Action action, long minSize, long minRequests, long sizeRange, long requestRange) {
        // Divide in double: long / long would truncate every normalised value to 0 or 1.
        double normalizedSize = sizeRange == 0
                ? 0.0d
                : (stateMapMeta.getEstimatedMemorySize() - minSize) / (double) sizeRange;
        double normalizedRequests = requestRange == 0
                ? 0.0d
                : (stateMapMeta.getNumRequests() - minRequests) / (double) requestRange;

        double sizeWeight;
        double requestWeight;
        double weightSum;
        switch (action) {
            case SPILL:
                sizeWeight = WEIGHT_SPILL_RETAINED_SIZE;
                requestWeight = WEIGHT_SPILL_REQUEST_RATE;
                weightSum = WEIGHT_SPILL_SUM;
                break;
            case LOAD:
                sizeWeight = WEIGHT_LOAD_RETAINED_SIZE;
                requestWeight = WEIGHT_LOAD_REQUEST_RATE;
                weightSum = WEIGHT_LOAD_SUM;
                break;
            default:
                throw new RuntimeException("Unsupported action: " + action);
        }
        return ((sizeWeight * normalizedSize) + (requestWeight * normalizedRequests)) / weightSum;
    }

    /**
     * Adds two floats via their decimal string representations, then converts the exact
     * decimal sum back to {@code float}.
     */
    private float floatSum(float f, float f2) {
        BigDecimal left = new BigDecimal(Float.toString(f));
        BigDecimal right = new BigDecimal(Float.toString(f2));
        BigDecimal sum = left.add(right);
        return sum.floatValue();
    }

    /**
     * Subtracts {@code f2} from {@code f} via their decimal string representations, then
     * converts the exact decimal difference back to {@code float}.
     */
    private float floatSub(float f, float f2) {
        BigDecimal minuend = new BigDecimal(Float.toString(f));
        BigDecimal subtrahend = new BigDecimal(Float.toString(f2));
        BigDecimal difference = minuend.subtract(subtrahend);
        return difference.floatValue();
    }

    /** Returns the GC time threshold in milliseconds, as read from the configuration. */
    @VisibleForTesting
    long getGcTimeThreshold() {
        return gcTimeThreshold;
    }

    /** Returns the minimum interval between two spill/load executions, in milliseconds. */
    @VisibleForTesting
    long getTriggerInterval() {
        return triggerInterval;
    }

    /** Returns the minimum interval between two heap status polls, in milliseconds. */
    @VisibleForTesting
    long getResourceCheckInterval() {
        return resourceCheckInterval;
    }

    /** Returns the max memory reported by the heap status monitor at construction time. */
    @VisibleForTesting
    long getMaxMemory() {
        return maxMemory;
    }

    /** Returns the configured spill size ratio that is attached to every spill decision. */
    public float getSpillSizeRatio() {
        return spillSizeRatio;
    }

    /** Returns the used-memory size below which {@code decideAction} chooses to load. */
    @VisibleForTesting
    long getLoadStartSize() {
        return loadStartSize;
    }

    /** Returns the size used as the upper target when computing the load ratio. */
    @VisibleForTesting
    long getLoadEndSize() {
        return loadEndSize;
    }
}
