package io.prestosql.snapshot;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterators;
import io.airlift.log.Logger;
import io.hetu.core.transport.execution.buffer.PagesSerde;
import io.hetu.core.transport.execution.buffer.SerializedPage;
import io.prestosql.operator.OperatorContext;
import io.prestosql.operator.TaskContext;
import io.prestosql.spi.Page;
import io.prestosql.spi.snapshot.MarkerPage;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

/* loaded from: input_file:io/prestosql/snapshot/MultiInputSnapshotState.class */
/**
 * Snapshot/resume bookkeeping for a component (an operator or a task component) that receives pages
 * from several input channels, identified by each page's origin.
 *
 * <p>Marker pages start a capture (or, when resuming, a restore) for a snapshot id. Data pages that
 * arrive on a channel whose marker for an in-progress snapshot has not been seen yet are recorded into
 * that snapshot's state list. Once markers from all input channels reported by
 * {@link MultiInputRestorable#getInputChannels} have been seen, the captured state is stored through
 * the {@link TaskSnapshotManager}.
 */
public class MultiInputSnapshotState {
    private static final Logger LOG = Logger.get(MultiInputSnapshotState.class);

    private final MultiInputRestorable restorable;
    // Label used in log messages: "<restorable class> (<state id generated for snapshot 0>)"
    private final String restorableId;
    private final TaskSnapshotManager snapshotManager;
    private final PagesSerde pagesSerde;
    private final Function<Long, SnapshotStateId> snapshotStateIdGenerator;
    // Last value returned by restorable.getInputChannels(...)
    private Optional<Set<String>> inputChannels;
    // Pages handed out before any new input is pulled: the remainder of a restored state list, or a
    // data page held back by a marker-only request. null when nothing has been queued.
    private Iterator<?> pendingPages;
    // In-progress captures/resumes, oldest first
    private final List<SnapshotState> states = new ArrayList<>();
    // Markers that started a new capture; drained by nextMarker()
    private final List<MarkerPage> pendingMarkers = new ArrayList<>();

    /**
     * Bookkeeping for one capture or one resume that has received a marker from at least one input
     * channel. Constructor and fields are private, so only the enclosing class can use it.
     */
    public static class SnapshotState {
        private final long snapshotId;
        // true when created from a resuming marker, false for a capture marker
        private final boolean resuming;
        // Input channels whose marker for this snapshot has already arrived
        private final Set<String> markedChannels = new HashSet<>();
        // For a capture: the restorable's own state (if capture succeeded) followed by captured input
        // pages; the whole list is handed to TaskSnapshotManager.storeState.
        private final List<Object> states = new ArrayList<>();

        private SnapshotState(MarkerPage marker) {
            this.snapshotId = marker.getSnapshotId();
            this.resuming = marker.isResuming();
        }
    }

    /**
     * Creates a state tracker for an operator, using the operator's task snapshot manager and the
     * driver's serde.
     */
    public static MultiInputSnapshotState forOperator(MultiInputRestorable multiInputRestorable, OperatorContext operatorContext) {
        return new MultiInputSnapshotState(
                multiInputRestorable,
                operatorContext.getDriverContext().getPipelineContext().getTaskContext().getSnapshotManager(),
                operatorContext.getDriverContext().getSerde(),
                snapshotId -> SnapshotStateId.forOperator(snapshotId.longValue(), operatorContext));
    }

    /**
     * Creates a state tracker for a task-level component, using a fresh serde from the task's serde
     * factory and the caller-supplied state id generator.
     */
    public static MultiInputSnapshotState forTaskComponent(MultiInputRestorable multiInputRestorable, TaskContext taskContext, Function<Long, SnapshotStateId> function) {
        return new MultiInputSnapshotState(multiInputRestorable, taskContext.getSnapshotManager(), taskContext.getSerdeFactory().createPagesSerde(), function);
    }

    /**
     * @param multiInputRestorable component whose state is captured and restored
     * @param taskSnapshotManager receives captured state and capture/restore results
     * @param pagesSerde used to (de)serialize pages and captured state
     * @param function maps a snapshot id to the state id of this component; it is also invoked once here
     *     with {@code 0L} to build the label used in log messages
     */
    public MultiInputSnapshotState(MultiInputRestorable multiInputRestorable, TaskSnapshotManager taskSnapshotManager, PagesSerde pagesSerde, Function<Long, SnapshotStateId> function) {
        this.restorable = multiInputRestorable;
        this.restorableId = String.format("%s (%s)", multiInputRestorable.getClass().getSimpleName(), function.apply(0L).getId());
        this.snapshotManager = taskSnapshotManager;
        this.pagesSerde = pagesSerde;
        this.snapshotStateIdGenerator = function;
    }

    /**
     * Returns the next page that should be delivered to the target. Pending pages are returned first;
     * otherwise one page is pulled from {@code supplier} and run through snapshot bookkeeping.
     *
     * @return empty if the supplier had nothing, or the pulled page was consumed by the bookkeeping
     */
    public Optional<Page> processPage(Supplier<Page> supplier) {
        return Optional.ofNullable(processPage(supplier, false, false));
    }

    /** Serialized-page variant of {@link #processPage(Supplier)}; results are returned serialized. */
    public Optional<SerializedPage> processSerializedPage(Supplier<SerializedPage> supplier) {
        return Optional.ofNullable(processPage(supplier, true, false));
    }

    /**
     * Like {@link #processPage(Supplier)}, but only returns a marker page. Returns empty while data
     * pages are pending; a data page pulled from {@code supplier} is queued as pending instead of being
     * returned.
     */
    public Optional<Page> nextMarker(Supplier<Page> supplier) {
        return Optional.ofNullable(processPage(supplier, false, true));
    }

    /** Serialized-page variant of {@link #nextMarker(Supplier)}. */
    public Optional<SerializedPage> nextSerializedMarker(Supplier<SerializedPage> supplier) {
        return Optional.ofNullable(processPage(supplier, true, true));
    }

    /**
     * Feeds every page of {@code list} through {@link #processPage(Supplier)} and returns, in order,
     * all pages that should be delivered (including any pending pages).
     */
    public List<Page> processPages(List<Page> list) {
        Iterator<Page> iterator = list.iterator();
        return processPages(() -> iterator.hasNext() ? iterator.next() : null, iterator::hasNext, false);
    }

    /** Serialized-page variant of {@link #processPages(List)}. */
    public List<SerializedPage> processSerializedPages(List<SerializedPage> list) {
        Iterator<SerializedPage> iterator = list.iterator();
        return processPages(() -> iterator.hasNext() ? iterator.next() : null, iterator::hasNext, true);
    }

    /**
     * Repeatedly calls {@link #processPage(Supplier, boolean, boolean)} until it yields nothing and
     * {@code hasMoreInput} reports the input exhausted.
     */
    private <T> List<T> processPages(Supplier<T> pageSupplier, Supplier<Boolean> hasMoreInput, boolean serialized) {
        List<T> result = new ArrayList<>();
        while (true) {
            T page = processPage(pageSupplier, serialized, false);
            if (page != null) {
                result.add(page);
            }
            else if (!hasMoreInput.get()) {
                return result;
            }
        }
    }

    /**
     * Core page handling.
     *
     * @param pageSupplier source of new pages ({@link Page} or {@link SerializedPage}); returns null
     *     when nothing is available
     * @param serialized whether a non-marker-only result is returned as a {@link SerializedPage}
     * @param markerOnly when true, only a marker page is returned; a data page is queued as pending
     * @return the page to deliver, or null if there is none right now
     */
    @SuppressWarnings("unchecked")
    private <T> T processPage(Supplier<T> pageSupplier, boolean serialized, boolean markerOnly) {
        if (markerOnly && pendingPages != null && pendingPages.hasNext()) {
            // A marker must not be returned ahead of data pages that are still pending
            return null;
        }

        SerializedPage serializedPage = pollPendingPage();
        Page page;
        if (serializedPage != null) {
            // Pending pages are delivered without going through processNewPage
            page = pagesSerde.deserialize(serializedPage);
        }
        else {
            T input = pageSupplier.get();
            if (input == null) {
                return null;
            }
            Optional<String> origin;
            if (input instanceof SerializedPage) {
                serializedPage = (SerializedPage) input;
                origin = serializedPage.getOrigin();
                page = pagesSerde.deserialize(serializedPage);
            }
            else {
                page = (Page) input;
                origin = page.getOrigin();
            }
            if (processNewPage(origin, page)) {
                // Consumed by the snapshot bookkeeping; nothing to deliver
                return null;
            }
            if (markerOnly) {
                if (page instanceof MarkerPage) {
                    // Hand the marker back in the same form it was supplied in
                    return serializedPage != null ? (T) pagesSerde.serialize(page) : (T) page;
                }
                // Data page: already recorded by processNewPage; queue it for the next call
                pendingPages = Iterators.singletonIterator(pagesSerde.serialize(page));
                return null;
            }
        }
        return serialized ? (T) pagesSerde.serialize(page) : (T) page;
    }

    /**
     * Runs snapshot bookkeeping for a newly received page.
     *
     * @param optional the page's origin channel; only markers may lack one
     * @return true if the page must not be delivered to the caller
     */
    private boolean processNewPage(Optional<String> optional, Page page) {
        MarkerPage marker = page instanceof MarkerPage ? (MarkerPage) page : null;
        if (!optional.isPresent()) {
            Preconditions.checkState(marker != null);
            LOG.debug("Sending marker '%s' directly to target '%s'", marker, restorableId);
            return false;
        }

        String channel = optional.get();
        if (marker == null) {
            return processInput(channel, page);
        }

        LOG.debug("Received marker '%s' from source '%s' to target '%s'", marker, channel, restorableId);
        if (marker.isResuming()) {
            boolean alreadyStarted = resume(channel, marker);
            if (!alreadyStarted) {
                LOG.debug("Sending resume marker '%s' from source '%s' to target '%s'", marker, channel, restorableId);
            }
            return alreadyStarted;
        }

        boolean alreadyStarted = processMarker(channel, marker);
        if (!alreadyStarted) {
            LOG.debug("Sending marker '%s' from source '%s' to target '%s'", marker, channel, restorableId);
        }
        return alreadyStarted;
    }

    /** Finds the in-progress state with the same snapshot id and resuming flag as {@code markerPage}, or null. */
    private SnapshotState snapshotStateById(MarkerPage markerPage) {
        for (SnapshotState state : states) {
            if (state.snapshotId == markerPage.getSnapshotId() && state.resuming == markerPage.isResuming()) {
                return state;
            }
        }
        return null;
    }

    /**
     * Handles a resuming marker from {@code str}. The first such marker for a snapshot clears all
     * in-progress states and restores the saved state; the state is dropped once markers from every
     * input channel have arrived.
     *
     * @return true if a resume for this snapshot had already been started (the marker is not delivered)
     */
    private boolean resume(String str, MarkerPage markerPage) {
        long snapshotId = markerPage.getSnapshotId();
        SnapshotState snapshot = snapshotStateById(markerPage);
        SnapshotStateId stateId = snapshotStateIdGenerator.apply(snapshotId);
        boolean alreadyStarted = snapshot != null;
        if (!alreadyStarted) {
            states.clear();
            try {
                Optional<Object> saved = snapshotManager.loadState(stateId);
                if (!saved.isPresent()) {
                    snapshotManager.failedToRestore(stateId, true);
                    LOG.warn("Can't locate saved state for snapshot %d, component %s", snapshotId, restorableId);
                }
                else if (saved.get() == TaskSnapshotManager.NO_STATE) {
                    snapshotManager.failedToRestore(stateId, true);
                    LOG.error("BUG! State of component %s has never been stored successfully before snapshot %d", restorableId, snapshotId);
                }
                else {
                    // First element restores the component; the rest stay pending and are delivered as pages
                    Iterator<?> restored = ((List<?>) saved.get()).listIterator();
                    pendingPages = restored;
                    restorable.restore(restored.next(), pagesSerde);
                    LOG.debug("Successfully restored state to snapshot %d for %s", snapshotId, restorableId);
                    snapshotManager.succeededToRestore(stateId);
                }
            }
            catch (Exception e) {
                LOG.warn(e, "Failed to restore snapshot state for %s: %s", stateId, e.getMessage());
                snapshotManager.failedToRestore(stateId, false);
            }
            snapshot = new SnapshotState(markerPage);
            states.add(snapshot);
        }

        if (!snapshot.markedChannels.add(str)) {
            LOG.error("Received duplicate marker '%s' from source '%s' to target '%s'", markerPage, str, restorableId);
        }

        inputChannels = restorable.getInputChannels(markerPage.getTaskCount());
        // NOTE(review): presumably resets the count so it is only applied once — confirm
        markerPage.setTaskCount(0);
        if (inputChannels.isPresent()) {
            Set<String> channels = inputChannels.get();
            Preconditions.checkState(channels.containsAll(snapshot.markedChannels));
            if (channels.size() == snapshot.markedChannels.size()) {
                // Every input channel has delivered this resume marker
                Preconditions.checkState(states.get(0) == snapshot, "resume state should be the first one");
                states.remove(0);
            }
        }
        return alreadyStarted;
    }

    /**
     * Handles a capture marker from {@code str}. The first marker for a snapshot captures the
     * restorable's state and queues the marker for {@link #nextMarker()}; once markers from every input
     * channel have arrived, the accumulated state is stored and older in-progress states are reported
     * as failed.
     *
     * @return true if the capture had already been started or the marker is a duplicate (not delivered)
     */
    private boolean processMarker(String str, MarkerPage markerPage) {
        long snapshotId = markerPage.getSnapshotId();
        SnapshotState snapshot = snapshotStateById(markerPage);
        boolean alreadyStarted = snapshot != null;
        if (!alreadyStarted) {
            snapshot = new SnapshotState(markerPage);
            pendingMarkers.add(markerPage);
            try {
                snapshot.states.add(restorable.capture(pagesSerde));
            }
            catch (Exception e) {
                LOG.warn(e, "Failed to capture and store snapshot state");
                snapshotManager.failedToCapture(snapshotStateIdGenerator.apply(snapshotId));
            }
            states.add(snapshot);
        }

        if (!snapshot.markedChannels.add(str)) {
            LOG.error("Received duplicate marker '%s' from source '%s' to target '%s'", markerPage, str, restorableId);
            return true;
        }

        inputChannels = restorable.getInputChannels(markerPage.getTaskCount());
        // NOTE(review): presumably resets the count so it is only applied once — confirm
        markerPage.setTaskCount(0);
        if (inputChannels.isPresent()) {
            Set<String> channels = inputChannels.get();
            Preconditions.checkState(channels.containsAll(snapshot.markedChannels));
            if (channels.size() == snapshot.markedChannels.size()) {
                SnapshotStateId stateId = snapshotStateIdGenerator.apply(snapshotId);
                try {
                    snapshotManager.storeState(stateId, snapshot.states);
                    snapshotManager.succeededToCapture(stateId);
                    LOG.debug("Successfully saved state to snapshot %d for %s", snapshotId, restorableId);
                }
                catch (Exception e) {
                    LOG.warn(e, "Failed to capture and store snapshot state");
                    snapshotManager.failedToCapture(stateId);
                }
                // Drop this state and every older one; the older ones are reported as failed
                int index = states.indexOf(snapshot);
                states.remove(index);
                for (int i = 0; i < index; i++) {
                    SnapshotState stale = states.remove(0);
                    SnapshotStateId staleId = snapshotStateIdGenerator.apply(stale.snapshotId);
                    if (stale.resuming) {
                        snapshotManager.failedToRestore(staleId, false);
                    }
                    else {
                        snapshotManager.failedToCapture(staleId);
                    }
                }
            }
        }
        return alreadyStarted;
    }

    /**
     * Records a data page from {@code str} into every in-progress state that has not yet received a
     * marker from that channel. Data pages are always delivered.
     *
     * @return always false
     */
    private boolean processInput(String str, Page page) {
        // Serialize and capture lazily, at most once per page
        Object captured = null;
        for (SnapshotState snapshot : states) {
            if (!snapshot.markedChannels.contains(str)) {
                if (captured == null) {
                    captured = pagesSerde.serialize(page).capture(pagesSerde);
                }
                snapshot.states.add(captured);
            }
        }
        return false;
    }

    /** Returns true if pending data pages or undrained markers exist. */
    public boolean hasPendingPages() {
        return hasPendingDataPages() || !pendingMarkers.isEmpty();
    }

    /** Returns true if pending data pages (restored or held back) remain to be delivered. */
    public boolean hasPendingDataPages() {
        return pendingPages != null && pendingPages.hasNext();
    }

    /**
     * Takes the next pending page, converting captured state back to a {@link SerializedPage} when
     * needed. Clears {@code pendingPages} once exhausted.
     *
     * @return the page, or null if nothing is pending
     */
    private SerializedPage pollPendingPage() {
        if (pendingPages == null) {
            return null;
        }
        if (!pendingPages.hasNext()) {
            pendingPages = null;
            return null;
        }
        Object next = pendingPages.next();
        return next instanceof SerializedPage ? (SerializedPage) next : SerializedPage.restoreSerializedPage(next);
    }

    /** Removes and returns the oldest marker that started a capture, or null if there is none. */
    public MarkerPage nextMarker() {
        return pendingMarkers.isEmpty() ? null : pendingMarkers.remove(0);
    }
}
