package io.prestosql.snapshot;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import io.airlift.log.Logger;
import io.prestosql.execution.TaskId;
import io.prestosql.operator.Operator;
import io.prestosql.operator.exchange.LocalMergeSourceOperator;
import io.prestosql.snapshot.SnapshotComponentCounter;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;

/* loaded from: input_file:io/prestosql/snapshot/TaskSnapshotManager.class */
/**
 * Tracks snapshot capture and restore progress for the components of one task, and delegates
 * persistence of component state and files to {@link SnapshotUtils}.
 *
 * <p>Per snapshot id, component completion is counted by a {@link SnapshotComponentCounter}
 * sized from {@link #setTotalComponents(int)}. When a counter reports a changed result that is
 * "done", the result is recorded and, on the coordinator, forwarded to the query-level
 * {@link QuerySnapshotManager}.
 *
 * <p>The synchronized maps and synchronized blocks indicate that methods are called from several
 * threads concurrently.
 */
public class TaskSnapshotManager {
    private static final Logger LOG = Logger.get(TaskSnapshotManager.class);

    /**
     * Sentinel returned (wrapped in an Optional) by {@link #loadState} when no earlier snapshot
     * exists to fall back to.
     */
    public static final Object NO_STATE = new Object();

    private final TaskId taskId;
    private final SnapshotUtils snapshotUtils;
    // -1 until set. Volatile because it is written by setTotalComponents without a lock and read
    // outside locks by other threads. The compound update in updateFinishedComponents is done
    // while holding the captureComponentCounters lock.
    private volatile int totalComponents = -1;
    private final Map<Long, SnapshotComponentCounter<SnapshotStateId>> captureComponentCounters =
            Collections.synchronizedMap(new LinkedHashMap<>());
    // Guarded by synchronizing on the map itself.
    private final Map<Long, SnapshotResult> captureResults = new LinkedHashMap<>();
    private final Map<Long, SnapshotComponentCounter<SnapshotStateId>> restoreComponentCounters =
            Collections.synchronizedMap(new LinkedHashMap<>());
    // Guarded by synchronizing on the object itself.
    private final RestoreResult restoreResult = new RestoreResult();

    public TaskSnapshotManager(TaskId taskId, SnapshotUtils snapshotUtils) {
        this.taskId = taskId;
        this.snapshotUtils = snapshotUtils;
    }

    /** Persists {@code state} under {@code snapshotStateId} via {@link SnapshotUtils#storeState}. */
    public void storeState(SnapshotStateId snapshotStateId, Object state) throws Exception {
        snapshotUtils.storeState(snapshotStateId, state);
    }

    /**
     * Loads the state stored for {@code snapshotStateId}. If none exists, this falls back to the
     * closest earlier snapshot that completed successfully, skipping earlier snapshots whose
     * result is {@link SnapshotResult#NA}. The query's snapshot results are loaded lazily, at
     * most once.
     *
     * @return the loaded state; {@code Optional.of(NO_STATE)} when there is no earlier snapshot
     *         to fall back to (previous id resolves to 0); or an empty Optional when the closest
     *         earlier non-NA snapshot did not succeed
     */
    public Optional<Object> loadState(SnapshotStateId snapshotStateId) throws Exception {
        SnapshotStateId currentId = snapshotStateId;
        Optional<Object> state = snapshotUtils.loadState(currentId);
        Map<Long, SnapshotResult> snapshotResults = null;
        while (!state.isPresent()) {
            if (snapshotResults == null) {
                snapshotResults = snapshotUtils.loadSnapshotResult(currentId.getTaskId().getQueryId().getId());
            }
            OptionalLong previousId = getPreviousSnapshotIdIfComplete(snapshotResults, currentId.getSnapshotId());
            if (!previousId.isPresent()) {
                return state;
            }
            if (previousId.getAsLong() == 0) {
                return Optional.of(NO_STATE);
            }
            currentId = currentId.withSnapshotId(previousId.getAsLong());
            state = snapshotUtils.loadState(currentId);
        }
        return state;
    }

    /** Persists the file at {@code path} under {@code snapshotStateId} via {@link SnapshotUtils#storeFile}. */
    public void storeFile(SnapshotStateId snapshotStateId, Path path) throws Exception {
        snapshotUtils.storeFile(snapshotStateId, path);
    }

    /**
     * Restores the file stored for {@code snapshotStateId} into {@code path}. It uses the same
     * fallback to earlier successful snapshots as {@link #loadState}.
     *
     * @return {@code true} if a file was loaded; {@code false} if the closest earlier non-NA
     *         snapshot did not succeed; {@code null} if there is no earlier snapshot to fall back
     *         to (the file-based equivalent of {@link #NO_STATE})
     * @throws NullPointerException if {@code path} is null
     */
    public Boolean loadFile(SnapshotStateId snapshotStateId, Path path) throws Exception {
        Objects.requireNonNull(path, "path is null");
        SnapshotStateId currentId = snapshotStateId;
        Map<Long, SnapshotResult> snapshotResults = null;
        while (!snapshotUtils.loadFile(currentId, path)) {
            if (snapshotResults == null) {
                snapshotResults = snapshotUtils.loadSnapshotResult(currentId.getTaskId().getQueryId().getId());
            }
            OptionalLong previousId = getPreviousSnapshotIdIfComplete(snapshotResults, currentId.getSnapshotId());
            if (!previousId.isPresent()) {
                return false;
            }
            if (previousId.getAsLong() == 0) {
                return null;
            }
            currentId = currentId.withSnapshotId(previousId.getAsLong());
        }
        return true;
    }

    /**
     * Walks {@code snapshotResults} backwards in iteration order and inspects entries whose id is
     * smaller than {@code snapshotId}. NA entries are skipped.
     *
     * <p>NOTE(review): assumes {@code snapshotResults} iterates in ascending snapshot-id order
     * (for example, insertion-ordered); confirm against {@code SnapshotUtils.loadSnapshotResult}.
     *
     * @return the id of the first such non-NA entry if it is SUCCESSFUL; empty if that entry has
     *         any other result; {@code 0} if every such entry is NA or there is none
     */
    private OptionalLong getPreviousSnapshotIdIfComplete(Map<Long, SnapshotResult> snapshotResults, long snapshotId) {
        Objects.requireNonNull(snapshotResults, "snapshotResults is null");
        // Copy the entries so they can be visited from the last one backwards.
        ArrayList<Map.Entry<Long, SnapshotResult>> entries = new ArrayList<>(snapshotResults.entrySet());
        for (int index = entries.size() - 1; index >= 0; index--) {
            Map.Entry<Long, SnapshotResult> entry = entries.get(index);
            long candidateId = entry.getKey();
            if (candidateId >= snapshotId) {
                continue;
            }
            SnapshotResult result = entry.getValue();
            if (result == SnapshotResult.SUCCESSFUL) {
                return OptionalLong.of(candidateId);
            }
            if (result != SnapshotResult.NA) {
                return OptionalLong.empty();
            }
        }
        return OptionalLong.of(0L);
    }

    /** Sets the number of components whose capture/restore outcome is counted per snapshot. */
    public void setTotalComponents(int totalComponents) {
        this.totalComponents = totalComponents;
    }

    public void succeededToCapture(SnapshotStateId snapshotStateId) {
        updateCapture(snapshotStateId, SnapshotComponentCounter.ComponentState.SUCCESSFUL);
    }

    public void failedToCapture(SnapshotStateId snapshotStateId) {
        LOG.debug("Failed to capture snapshot %d for component %s", snapshotStateId.getSnapshotId(), snapshotStateId);
        updateCapture(snapshotStateId, SnapshotComponentCounter.ComponentState.FAILED);
    }

    public void succeededToRestore(SnapshotStateId snapshotStateId) {
        updateRestore(snapshotStateId, SnapshotComponentCounter.ComponentState.SUCCESSFUL);
    }

    /**
     * Records a failed restore for one component.
     *
     * @param fatal whether the failure is recorded as FAILED_FATAL rather than FAILED
     */
    public void failedToRestore(SnapshotStateId snapshotStateId, boolean fatal) {
        LOG.debug("Failed (fatal=%b) to restore snapshot %d for component %s", fatal, snapshotStateId.getSnapshotId(), snapshotStateId);
        updateRestore(snapshotStateId, fatal
                ? SnapshotComponentCounter.ComponentState.FAILED_FATAL
                : SnapshotComponentCounter.ComponentState.FAILED);
    }

    /**
     * Returns an immutable copy of the per-snapshot capture results. When there are no components
     * to track ({@code totalComponents == 0}), a single SUCCESSFUL entry under key -1 is returned.
     * NOTE(review): key -1 looks like a sentinel understood by consumers; confirm.
     */
    public Map<Long, SnapshotResult> getSnapshotCaptureResult() {
        if (totalComponents == 0) {
            return ImmutableMap.of(-1L, SnapshotResult.SUCCESSFUL);
        }
        synchronized (captureResults) {
            return ImmutableMap.copyOf(captureResults);
        }
    }

    /** Returns the live (mutable, shared) restore result object; it is not a copy. */
    public RestoreResult getSnapshotRestoreResult() {
        return restoreResult;
    }

    /**
     * Records {@code componentState} for one component of a snapshot capture. When the
     * snapshot's aggregate result changes and is done, the result is reported to the coordinator
     * (if this is the coordinator) and logged.
     */
    private void updateCapture(SnapshotStateId snapshotStateId, SnapshotComponentCounter.ComponentState componentState) {
        TaskId stateTaskId = snapshotStateId.getTaskId();
        Preconditions.checkState(totalComponents > 0);
        long snapshotId = snapshotStateId.getSnapshotId();
        SnapshotComponentCounter<SnapshotStateId> counter = captureComponentCounters.computeIfAbsent(
                snapshotId, ignored -> new SnapshotComponentCounter<>(totalComponents));
        if (!counter.updateComponent(snapshotStateId, componentState)) {
            return;
        }
        SnapshotResult snapshotResult = counter.getSnapshotResult();
        synchronized (captureResults) {
            SnapshotResult previousResult = captureResults.put(snapshotId, snapshotResult);
            if (previousResult != snapshotResult && snapshotResult.isDone()) {
                coordinatorQuerySnapshotManager(stateTaskId)
                        .ifPresent(manager -> manager.updateQueryCapture(stateTaskId, snapshotId, snapshotResult));
                LOG.debug("Finished capturing snapshot %d for task %s. Result is %s.", snapshotId, stateTaskId, snapshotResult);
            }
        }
    }

    /**
     * Records {@code componentState} for one component of a snapshot restore. When the restore
     * result changes and is done, it is reported to the coordinator (if this is the coordinator)
     * and logged.
     */
    private void updateRestore(SnapshotStateId snapshotStateId, SnapshotComponentCounter.ComponentState componentState) {
        TaskId stateTaskId = snapshotStateId.getTaskId();
        Preconditions.checkState(totalComponents > 0);
        long snapshotId = snapshotStateId.getSnapshotId();
        SnapshotComponentCounter<SnapshotStateId> counter = restoreComponentCounters.computeIfAbsent(
                snapshotId, ignored -> new SnapshotComponentCounter<>(totalComponents));
        if (!counter.updateComponent(snapshotStateId, componentState)) {
            return;
        }
        SnapshotResult snapshotResult = counter.getSnapshotResult();
        synchronized (restoreResult) {
            if (restoreResult.setSnapshotResult(snapshotId, snapshotResult) && snapshotResult.isDone()) {
                coordinatorQuerySnapshotManager(stateTaskId)
                        .ifPresent(manager -> manager.updateQueryRestore(stateTaskId, Optional.of(restoreResult)));
                LOG.debug("Finished restoring snapshot %d for task %s. Result is %s.", snapshotId, stateTaskId, snapshotResult);
            }
        }
    }

    /**
     * Returns the query-level snapshot manager for {@code stateTaskId}'s query. The result is
     * empty when this node is not the coordinator, or when no manager is registered.
     */
    private Optional<QuerySnapshotManager> coordinatorQuerySnapshotManager(TaskId stateTaskId) {
        if (!snapshotUtils.isCoordinator()) {
            return Optional.empty();
        }
        return Optional.ofNullable(snapshotUtils.getQuerySnapshotManager(stateTaskId.getQueryId()));
    }

    /**
     * Marks every operator in {@code finishedOperators} as successfully captured for each snapshot
     * currently being tracked. It then lowers {@code totalComponents} by the number of operators,
     * so counters created afterwards expect fewer components.
     */
    public void updateFinishedComponents(Collection<Operator> finishedOperators) {
        Preconditions.checkState(totalComponents > 0);
        synchronized (captureComponentCounters) {
            // Iterating a synchronizedMap view requires holding its lock. updateCapture only
            // looks up keys that already exist here, so the map is not structurally modified.
            for (Long snapshotId : captureComponentCounters.keySet()) {
                for (Operator operator : finishedOperators) {
                    updateCapture(componentStateId(snapshotId, operator), SnapshotComponentCounter.ComponentState.SUCCESSFUL);
                }
            }
            totalComponents -= finishedOperators.size();
        }
    }

    /**
     * Builds the id under which {@code operator}'s state for {@code snapshotId} is tracked. A
     * {@link LocalMergeSourceOperator} is identified as a task component by its plan node id; any
     * other operator is identified by its operator context.
     */
    private static SnapshotStateId componentStateId(long snapshotId, Operator operator) {
        if (operator instanceof LocalMergeSourceOperator) {
            return SnapshotStateId.forTaskComponent(
                    snapshotId,
                    operator.getOperatorContext().getDriverContext().getPipelineContext().getTaskContext(),
                    ((LocalMergeSourceOperator) operator).getPlanNodeId());
        }
        return SnapshotStateId.forOperator(snapshotId, operator.getOperatorContext());
    }

    @Override
    public String toString() {
        return String.format("%s, with total component %d", taskId, totalComponents);
    }
}
