package org.apache.flink.runtime.state.heap;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.state.StateEntry;
import org.apache.flink.runtime.state.StateRequestEntry;
import org.apache.flink.runtime.state.StateTransformationFunction;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/runtime/state/heap/ComplexStateMap.class */
/**
 * A {@link StateMap} whose entries are split between an on-heap {@link CopyOnWriteStateMap} and a
 * {@link SpillStateMap}.
 *
 * <p>Reads consult the heap map first and fall back to the spill map. Writes go to the spill map
 * when the key/namespace pair is already stored there, otherwise to the heap map. A running
 * request count is maintained: reads and writes increment it, and removals subtract the request
 * count the inner map reports for the removed entry. {@link #spillState()} and {@link
 * #loadState()} compare each entry's request count against the average per entry to decide which
 * entries to move between the two maps.
 *
 * <p>{@link #initSpillStateMap(Supplier)} must be called before any method that touches the spill
 * map; until then {@code spillMap} is {@code null} and those methods throw {@link
 * NullPointerException}.
 */
public class ComplexStateMap<K, N, S> extends StateMap<K, N, S> {

    /** Inner map for entries kept on heap. */
    private final CopyOnWriteStateMap<K, N, S> heapMap;

    /** Inner map for spilled entries; {@code null} until {@link #initSpillStateMap} is called. */
    private SpillStateMap<K, N, S> spillMap;

    /** Running total of requests, incremented on access and decreased on removal. */
    private int requestCount;

    /**
     * Creates a map that uses {@code stateMap} as its heap part.
     *
     * @param stateMap the heap map; must be a {@link CopyOnWriteStateMap}
     * @throws IllegalStateException (via {@link Preconditions#checkState}) if {@code stateMap} is
     *     not a {@link CopyOnWriteStateMap}
     */
    public ComplexStateMap(StateMap<K, N, S> stateMap) {
        Preconditions.checkState(
                stateMap instanceof CopyOnWriteStateMap,
                "Only CopyOnWriteStateMap can be use as heapMap");
        this.heapMap = (CopyOnWriteStateMap<K, N, S>) stateMap;
        this.requestCount = 0;
    }

    /**
     * Installs the spill map. Only the first call has an effect; later calls neither replace the
     * existing spill map nor invoke {@code supplier}.
     *
     * @param supplier factory for the spill map
     */
    public void initSpillStateMap(Supplier<SpillStateMap<K, N, S>> supplier) {
        if (spillMap == null) {
            spillMap = supplier.get();
        }
    }

    /** Returns the number of entries in the heap map plus the number in the spill map. */
    @Override
    public int size() {
        return heapMap.size() + spillMap.size();
    }

    /**
     * Looks up the state for {@code key}/{@code namespace}, checking the heap map first and the
     * spill map only when the heap map returns {@code null}. Counts as one request.
     */
    @Override
    public S get(K key, N namespace) {
        updateRequests();
        final S onHeap = heapMap.get(key, namespace);
        return onHeap != null ? onHeap : spillMap.get(key, namespace);
    }

    /** Returns whether either inner map contains the pair. Counts as one request. */
    @Override
    public boolean containsKey(K key, N namespace) {
        updateRequests();
        return heapMap.containsKey(key, namespace) || spillMap.containsKey(key, namespace);
    }

    /**
     * Stores {@code state} in the spill map if the pair already lives there, otherwise in the heap
     * map. Counts as one request.
     */
    @Override
    public void put(K key, N namespace, S state) {
        updateRequests();
        if (spillMap.containsKey(key, namespace)) {
            spillMap.put(key, namespace, state);
        } else {
            heapMap.put(key, namespace, state);
        }
    }

    /**
     * Same routing as {@link #put}, returning whatever the chosen inner map returns as the old
     * value. Counts as one request.
     */
    @Override
    public S putAndGetOld(K key, N namespace, S state) {
        updateRequests();
        if (spillMap.containsKey(key, namespace)) {
            return spillMap.putAndGetOld(key, namespace, state);
        }
        return heapMap.putAndGetOld(key, namespace, state);
    }

    /**
     * Removes the pair from the heap map; if the heap map did not shrink, removes it from the spill
     * map instead. The removed entry's request count, as reported by the inner map that handled the
     * removal attempt, is subtracted from the running total. Does not count as a request.
     */
    @Override
    public void remove(K key, N namespace) {
        final int heapSizeBefore = heapMap.size();
        heapMap.remove(key, namespace);
        if (heapMap.size() < heapSizeBefore) {
            removeRequests(heapMap.getRequestCountForLastRemovedKey());
            return;
        }
        spillMap.remove(key, namespace);
        removeRequests(spillMap.getRequestCountForLastRemovedKey());
    }

    /**
     * Removes the pair, trying the heap map first and the spill map when the heap map returned
     * {@code null}, and returns the removed state (or whatever the spill map returns). Subtracts the
     * removed entry's request count from the running total. Does not count as a request.
     */
    @Override
    public S removeAndGetOld(K key, N namespace) {
        final S fromHeap = heapMap.removeAndGetOld(key, namespace);
        if (fromHeap != null) {
            removeRequests(heapMap.getRequestCountForLastRemovedKey());
            return fromHeap;
        }
        final S fromSpill = spillMap.removeAndGetOld(key, namespace);
        removeRequests(spillMap.getRequestCountForLastRemovedKey());
        return fromSpill;
    }

    /**
     * Applies {@code transformation} in the spill map if the pair lives there, otherwise in the
     * heap map. Counts as one request.
     *
     * @throws Exception whatever the delegated transform throws
     */
    @Override
    public <T> void transform(
            K key,
            N namespace,
            T value,
            StateTransformationFunction<S, T> transformation)
            throws Exception {
        updateRequests();
        if (spillMap.containsKey(key, namespace)) {
            spillMap.transform(key, namespace, value, transformation);
        } else {
            heapMap.transform(key, namespace, value, transformation);
        }
    }

    /** Returns the keys of {@code namespace} from the heap map followed by those from the spill map. */
    @Override
    public Stream<K> getKeys(N namespace) {
        return Stream.concat(heapMap.getKeys(namespace), spillMap.getKeys(namespace));
    }

    /**
     * Returns a visitor that drains the heap map's visitor first and then the spill map's visitor.
     * Removals and updates made through the visitor are routed through this map's own {@link
     * #remove} and {@link #put}, so updates count as requests.
     */
    @Override
    public InternalKvState.StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(
            final int recommendedMaxNumberOfReturnedRecords) {
        return new InternalKvState.StateIncrementalVisitor<K, N, S>() {
            private final InternalKvState.StateIncrementalVisitor<K, N, S> heapVisitor =
                    heapMap.getStateIncrementalVisitor(recommendedMaxNumberOfReturnedRecords);
            private final InternalKvState.StateIncrementalVisitor<K, N, S> spillVisitor =
                    spillMap.getStateIncrementalVisitor(recommendedMaxNumberOfReturnedRecords);

            @Override
            public boolean hasNext() {
                return heapVisitor.hasNext() || spillVisitor.hasNext();
            }

            @Override
            public Collection<StateEntry<K, N, S>> nextEntries() {
                return heapVisitor.hasNext() ? heapVisitor.nextEntries() : spillVisitor.nextEntries();
            }

            @Override
            public void remove(StateEntry<K, N, S> stateEntry) {
                // Qualified: an unqualified remove(...) would resolve to this visitor's own method.
                ComplexStateMap.this.remove(stateEntry.getKey(), stateEntry.getNamespace());
            }

            @Override
            public void update(StateEntry<K, N, S> stateEntry, S newValue) {
                ComplexStateMap.this.put(stateEntry.getKey(), stateEntry.getNamespace(), newValue);
            }
        };
    }

    /** Returns a snapshot that combines snapshots of the heap map and the spill map. */
    @Override
    @Nonnull
    public ComplexStateMapSnapshot<K, N, S> stateSnapshot() {
        return new ComplexStateMapSnapshot<>(this, heapMap.stateSnapshot(), spillMap.stateSnapshot());
    }

    /** Returns the namespace size in the heap map plus that in the spill map. */
    @Override
    public int sizeOfNamespace(Object namespace) {
        return heapMap.sizeOfNamespace(namespace) + spillMap.sizeOfNamespace(namespace);
    }

    /** Returns an iterator over all heap-map entries followed by all spill-map entries. */
    @Override
    public Iterator<StateEntry<K, N, S>> iterator() {
        return new Iterator<StateEntry<K, N, S>>() {
            private final Iterator<StateEntry<K, N, S>> heapIterator = heapMap.iterator();
            private final Iterator<StateEntry<K, N, S>> spillIterator = spillMap.iterator();

            @Override
            public boolean hasNext() {
                return heapIterator.hasNext() || spillIterator.hasNext();
            }

            @Override
            public StateEntry<K, N, S> next() {
                return heapIterator.hasNext() ? heapIterator.next() : spillIterator.next();
            }
        };
    }

    /** Returns the running request total. */
    public int getRequestCount() {
        return requestCount;
    }

    /** Returns the number of entries in the heap map. */
    @VisibleForTesting
    public int getHeapSize() {
        return heapMap.size();
    }

    /**
     * Returns the fraction of all entries that are held in the spill map, in {@code [0, 1]}.
     *
     * @return {@code spillSize / totalSize}, or {@code 0} when the map is empty
     */
    public float getSpilledRatio() {
        final int total = size();
        // Divide in floating point: int division truncated every ratio below 1 to 0, and an empty
        // map threw ArithmeticException.
        return total == 0 ? 0.0f : (float) spillMap.size() / total;
    }

    /**
     * Moves heap entries whose request count is at or below the average requests per entry into
     * the spill map.
     */
    public void spillState() {
        transferState(heapMap, spillMap, false);
    }

    /**
     * Moves spilled entries whose request count is above the average requests per entry back into
     * the heap map.
     */
    public void loadState() {
        transferState(spillMap, heapMap, true);
    }

    /**
     * Moves entries from {@code source} to {@code target} based on their request count relative to
     * the average requests per entry ({@code requestCount / size()}).
     *
     * @param source map to move entries out of; every entry must be a {@link StateRequestEntry}
     * @param target map that receives the moved entries, with their request counts preserved
     * @param loading if {@code true}, moves entries above the average; otherwise moves entries at
     *     or below it
     * @throws IllegalStateException (via {@link Preconditions#checkState}) if an entry of {@code
     *     source} is not a {@link StateRequestEntry}
     */
    @SuppressWarnings("rawtypes")
    private void transferState(
            StateMap<K, N, S> source, StateMap<K, N, S> target, boolean loading) {
        final int total = size();
        if (total == 0) {
            // Nothing to move; also avoids the division by zero the average would otherwise hit.
            return;
        }
        // Floating-point division so the threshold is not truncated to an integer.
        final float averageRequests = (float) requestCount / total;

        final ArrayList<StateEntry<K, N, S>> moved = new ArrayList<>();
        for (StateEntry<K, N, S> entry : source) {
            Preconditions.checkState(
                    entry instanceof StateRequestEntry,
                    "Only StateRequestEntry for inner state maps");
            // Raw view used only to read the request count; key/namespace/state come from the
            // typed entry.
            final StateRequestEntry requestEntry = (StateRequestEntry) entry;
            final boolean hot = requestEntry.getRequestCount() > averageRequests;
            if (hot == loading) {
                target.putWithRequestCount(
                        entry.getKey(),
                        entry.getNamespace(),
                        entry.getState(),
                        requestEntry.getRequestCount());
                moved.add(
                        new StateEntry.SimpleStateEntry<>(
                                entry.getKey(), entry.getNamespace(), null));
            }
        }
        // Removal is deferred until iteration has finished so the source map is not modified
        // while its iterator is still in use.
        for (StateEntry<K, N, S> entry : moved) {
            source.remove(entry.getKey(), entry.getNamespace());
        }
    }

    /** Closes the spill map. */
    public void close() {
        spillMap.close();
    }

    /** Returns whether the spill map reports itself as closed. */
    boolean isClosed() {
        return spillMap.isClosed();
    }

    /** Records one request against this map. */
    private void updateRequests() {
        requestCount++;
    }

    /**
     * Subtracts the request count of a removed entry from the running total.
     *
     * @param removedRequests request count reported by the inner map for the removed entry
     */
    private void removeRequests(long removedRequests) {
        // Narrowed back to int because the running total is an int field.
        requestCount = (int) (requestCount - removedRequests);
    }
}
