package io.prestosql.snapshot;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import io.airlift.log.Logger;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.execution.QueryState;
import io.prestosql.execution.TaskId;
import io.prestosql.snapshot.SnapshotComponentCounter;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.QueryId;
import io.prestosql.spi.StandardErrorCode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;
import java.util.function.Consumer;
import java.util.stream.Collectors;

/* loaded from: input_file:io/prestosql/snapshot/QuerySnapshotManager.class */
/**
 * Tracks snapshot capture and restore progress for a single query.
 *
 * <p>An instance registers itself with {@link SnapshotUtils} under its query id on construction and
 * unregisters in {@link #doneQuery(QueryState)}. Per-snapshot capture results are kept in insertion
 * order and persisted through {@link SnapshotUtils#storeSnapshotResult}. When a restore is started,
 * a timer is armed; if the restore does not complete (or fails), the configured rescheduler is run.
 */
public class QuerySnapshotManager {
    private static final Logger LOG = Logger.get(QuerySnapshotManager.class);

    private final QueryId queryId;
    private final SnapshotUtils snapshotUtils;
    // Maximum number of resume attempts; 0 when no session was supplied.
    private final long maxRetry;
    // Restore timeout in milliseconds; 0 when no session was supplied.
    private final long retryTimeout;
    // Number of resume attempts made so far (incremented by getResumeSnapshotId()).
    private long retryCount;
    // Tasks that have neither finished nor been reported as fully captured.
    private final Set<TaskId> unfinishedTasks = Sets.newConcurrentHashSet();
    // Per-snapshot-id counters of task capture states.
    private final Map<Long, SnapshotComponentCounter<TaskId>> captureComponentCounters =
            Collections.synchronizedMap(new LinkedHashMap<>());
    // Query-level capture result per snapshot id, in insertion order.
    // Iteration must be done while holding this map's monitor (Collections.synchronizedMap contract).
    private final Map<Long, SnapshotResult> captureResults = Collections.synchronizedMap(new LinkedHashMap<>());
    // Per-snapshot-id counters of task restore states.
    private final Map<Long, SnapshotComponentCounter<TaskId>> restoreComponentCounters =
            Collections.synchronizedMap(new LinkedHashMap<>());
    private final RestoreResult restoreResult = new RestoreResult();
    // NOTE(review): listeners are registered and cleared in this class but never invoked here —
    // confirm whether queryRestoreComplete() is expected to notify them.
    private final List<Consumer<RestoreResult>> restoreCompleteListeners = Collections.synchronizedList(new ArrayList<>());
    // Snapshot id most recently chosen for resume; empty after a successful restore.
    private OptionalLong lastTriedId = OptionalLong.empty();
    // Active restore-timeout timer. Guarded by "this".
    private Optional<Timer> retryTimer = Optional.empty();
    private Runnable rescheduler = () -> {
    };

    /**
     * Creates the manager and registers it with {@code snapshotUtils} under {@code queryId}.
     *
     * @param queryId query this manager belongs to; must not be null
     * @param snapshotUtils shared snapshot utilities; must not be null
     * @param session session providing retry limits; when null, both the retry limit and the
     *     restore timeout are 0
     */
    public QuerySnapshotManager(QueryId queryId, SnapshotUtils snapshotUtils, Session session) {
        this.queryId = Objects.requireNonNull(queryId, "queryId is null");
        this.snapshotUtils = Objects.requireNonNull(snapshotUtils, "snapshotUtils is null");
        if (session == null) {
            this.maxRetry = 0L;
            this.retryTimeout = 0L;
        } else {
            this.maxRetry = SystemSessionProperties.getSnapshotMaxRetries(session);
            this.retryTimeout = SystemSessionProperties.getSnapshotRetryTimeout(session).toMillis();
        }
        snapshotUtils.addQuerySnapshotManager(queryId, this);
    }

    /** Sets the action run when a restore fails or times out. */
    public void setRescheduler(Runnable runnable) {
        this.rescheduler = runnable;
    }

    /** Delegates to {@link SnapshotUtils#isCoordinator()}. */
    public boolean isCoordinator() {
        return snapshotUtils.isCoordinator();
    }

    /** Records a task whose capture/completion must be tracked. */
    public void addNewTask(TaskId taskId) {
        unfinishedTasks.add(taskId);
    }

    /**
     * Chooses the snapshot to resume from for the next retry, counting it as one retry attempt.
     *
     * <p>If a snapshot is chosen, the restore-timeout timer is (re)armed. In all cases the set of
     * unfinished tasks and both capture and restore component counters are cleared.
     *
     * @return id of the snapshot to resume from, or empty if none is suitable
     * @throws PrestoException with {@code TOO_MANY_RESUMES} if the retry limit has been reached
     */
    public OptionalLong getResumeSnapshotId() throws PrestoException {
        if (retryCount >= maxRetry) {
            throw new PrestoException(StandardErrorCode.TOO_MANY_RESUMES, "Tried to recover query execution for too many times");
        }
        retryCount++;
        lastTriedId = getResumeSnapshotId(lastTriedId);
        if (lastTriedId.isPresent()) {
            startSnapshotRestoreTimer();
        }
        unfinishedTasks.clear();
        captureComponentCounters.clear();
        restoreComponentCounters.clear();
        return lastTriedId;
    }

    /**
     * Walks capture results from newest to oldest and returns the newest SUCCESSFUL snapshot that is
     * older than {@code prevSnapshotId}.
     *
     * <p>Snapshots newer than {@code prevSnapshotId} are left untouched. The previously tried
     * snapshot, and every older non-successful snapshot passed over, is marked {@code NA}. The
     * updated results are then persisted.
     *
     * @param prevSnapshotId snapshot tried on the previous attempt; empty means "no upper bound"
     * @return chosen snapshot id, or empty if none qualifies
     */
    private OptionalLong getResumeSnapshotId(OptionalLong prevSnapshotId) {
        OptionalLong resumeId = OptionalLong.empty();
        if (captureResults.isEmpty()) {
            LOG.debug("Can't find a suitable snapshot to resume for query '%s'", queryId.getId());
            return resumeId;
        }
        long upperBound = prevSnapshotId.orElse(Long.MAX_VALUE);
        synchronized (captureResults) {
            List<Map.Entry<Long, SnapshotResult>> entries = new ArrayList<>(captureResults.entrySet());
            // A real descending loop: every branch advances the index, so newer snapshots are
            // skipped instead of spinning forever.
            for (int i = entries.size() - 1; i >= 0; i--) {
                long snapshotId = entries.get(i).getKey();
                SnapshotResult result = entries.get(i).getValue();
                if (snapshotId == upperBound) {
                    // The previously tried snapshot did not work out; never retry it.
                    captureResults.put(snapshotId, SnapshotResult.NA);
                } else if (snapshotId < upperBound) {
                    if (result == SnapshotResult.SUCCESSFUL) {
                        resumeId = OptionalLong.of(snapshotId);
                        break;
                    }
                    captureResults.put(snapshotId, SnapshotResult.NA);
                }
                // snapshotId > upperBound: newer than the last attempt, leave as is.
            }
            saveQuerySnapshotResult();
        }
        if (resumeId.isPresent()) {
            LOG.debug("About to resume from snapshot %d for query '%s'", resumeId.getAsLong(), queryId.getId());
        } else {
            LOG.debug("Can't find a suitable snapshot to resume for query '%s'", queryId.getId());
        }
        return resumeId;
    }

    /** Marks every known snapshot of this query as {@code NA} and persists the result. */
    public void invalidateAllSnapshots() {
        synchronized (captureResults) {
            captureResults.replaceAll((snapshotId, result) -> SnapshotResult.NA);
            saveQuerySnapshotResult();
        }
    }

    /**
     * Handles the end of a restore attempt, if a restore timer is still active.
     *
     * <p>The timer is cancelled. On success, the last tried snapshot id is cleared. Otherwise a
     * warning is logged and the rescheduler is run. Does nothing if no timer is active (e.g. the
     * timeout already fired).
     */
    private synchronized void queryRestoreComplete(RestoreResult result) {
        if (!retryTimer.isPresent()) {
            return;
        }
        retryTimer.get().cancel();
        retryTimer = Optional.empty();
        if (result.getSnapshotResult() == SnapshotResult.SUCCESSFUL) {
            lastTriedId = OptionalLong.empty();
        } else {
            LOG.warn("Failed to restore snapshot for %s, snapshot %d", queryId.getId(), result.getSnapshotId());
            rescheduler.run();
        }
    }

    /**
     * Arms a timer that runs the rescheduler if the restore has not completed within
     * {@code retryTimeout} milliseconds. Any previously armed timer is cancelled.
     */
    private void startSnapshotRestoreTimer() {
        Timer timer = new Timer();
        TimerTask timeoutTask = new TimerTask() {
            @Override
            public void run() {
                // Lock the manager, not this TimerTask, so this is mutually exclusive with
                // queryRestoreComplete() and startSnapshotRestoreTimer().
                synchronized (QuerySnapshotManager.this) {
                    // Act only if this timer is still the active one. A newer restore attempt may
                    // have replaced it, or a completed restore may have cleared it.
                    if (retryTimer.isPresent() && retryTimer.get() == timer) {
                        LOG.warn("Snapshot restore timed out, failed to restore snapshot for %s, snapshot %s", queryId.getId(), lastTriedId);
                        retryTimer = Optional.empty();
                        // Release the timer's (non-daemon) thread now that it has fired.
                        timer.cancel();
                        rescheduler.run();
                    }
                }
            }
        };
        synchronized (this) {
            retryTimer.ifPresent(Timer::cancel);
            retryTimer = Optional.of(timer);
            // Schedule only after publishing the timer. The task blocks on our monitor until this
            // block exits, so even a 0 ms timeout sees the timer it belongs to.
            timer.schedule(timeoutTask, retryTimeout);
        }
    }

    /**
     * Clears all per-query state and unregisters this manager from {@link SnapshotUtils}.
     *
     * @param queryState final state of the query (used only for logging)
     */
    public void doneQuery(QueryState queryState) {
        LOG.debug("query will be removed with queryId = %s,%nstate = %s,%ncaptureComponentCounters = %s,%ncaptureResults = %s,%nrestoreComponentCounters = %s,%nrestoreResult = %s",
                queryId, queryState, captureComponentCounters, captureResults, restoreComponentCounters, restoreResult);
        resetForQuery();
        snapshotUtils.removeQuerySnapshotManager(queryId);
    }

    /** Drops all tracked state, cancels any pending restore timer and resets the restore result. */
    private void resetForQuery() {
        synchronized (this) {
            // A timer left armed would run the rescheduler for a query that is already done.
            retryTimer.ifPresent(Timer::cancel);
            retryTimer = Optional.empty();
        }
        unfinishedTasks.clear();
        captureComponentCounters.clear();
        captureResults.clear();
        restoreComponentCounters.clear();
        synchronized (restoreResult) {
            restoreResult.setSnapshotResult(0L, SnapshotResult.IN_PROGRESS);
        }
        restoreCompleteListeners.clear();
    }

    /** Registers a listener for restore completion. */
    public void addQueryRestoreCompleteListeners(Consumer<RestoreResult> consumer) {
        restoreCompleteListeners.add(consumer);
    }

    /**
     * Applies capture results reported by a task.
     *
     * <p>A negative snapshot id is a marker meaning the task is fully captured. Its result must be
     * SUCCESSFUL, and the task is then treated as captured for every tracked snapshot.
     *
     * @param taskId reporting task
     * @param results capture result per snapshot id
     * @throws IllegalArgumentException if a negative-id marker carries a non-SUCCESSFUL result
     */
    public void updateQueryCapture(TaskId taskId, Map<Long, SnapshotResult> results) {
        for (Map.Entry<Long, SnapshotResult> entry : results.entrySet()) {
            long snapshotId = entry.getKey();
            SnapshotResult result = entry.getValue();
            if (snapshotId < 0) {
                Preconditions.checkArgument(result == SnapshotResult.SUCCESSFUL,
                        "Full-capture marker for task %s must be SUCCESSFUL, got %s", taskId, result);
                updateCapturedComponents(ImmutableList.of(taskId), false);
            } else {
                updateQueryCapture(taskId, snapshotId, result);
            }
        }
    }

    /**
     * Records one task's capture result for one snapshot. Only FAILED and SUCCESSFUL results are
     * applied; any other result is ignored.
     */
    public void updateQueryCapture(TaskId taskId, long snapshotId, SnapshotResult result) {
        if (result == SnapshotResult.FAILED) {
            updateQueryCapture(snapshotId, taskId, SnapshotComponentCounter.ComponentState.FAILED);
        } else if (result == SnapshotResult.SUCCESSFUL) {
            updateQueryCapture(snapshotId, taskId, SnapshotComponentCounter.ComponentState.SUCCESSFUL);
        }
    }

    /**
     * Records a task's restore result, if present. FAILED, FAILED_FATAL and SUCCESSFUL results are
     * applied; any other result is ignored.
     */
    public void updateQueryRestore(TaskId taskId, Optional<RestoreResult> optional) {
        if (!optional.isPresent()) {
            return;
        }
        SnapshotResult result = optional.get().getSnapshotResult();
        long snapshotId = optional.get().getSnapshotId();
        if (result == SnapshotResult.FAILED) {
            LOG.debug("Failed to resume for: %s, snapshot %d", taskId, snapshotId);
            updateQueryRestore(snapshotId, taskId, SnapshotComponentCounter.ComponentState.FAILED);
        } else if (result == SnapshotResult.FAILED_FATAL) {
            LOG.debug("[FATAL] Failed to resume for: %s, snapshot %d", taskId, snapshotId);
            updateQueryRestore(snapshotId, taskId, SnapshotComponentCounter.ComponentState.FAILED_FATAL);
        } else if (result == SnapshotResult.SUCCESSFUL) {
            updateQueryRestore(snapshotId, taskId, SnapshotComponentCounter.ComponentState.SUCCESSFUL);
        }
    }

    /**
     * Persists all capture results whose {@code isDone()} is true, preserving insertion order.
     * Callers must hold the {@code captureResults} monitor, because this iterates the map.
     *
     * @throws RuntimeException wrapping any failure from the store
     */
    private void saveQuerySnapshotResult() {
        if (captureResults.isEmpty()) {
            return;
        }
        try {
            Map<Long, SnapshotResult> doneResults = captureResults.entrySet().stream()
                    .filter(entry -> entry.getValue().isDone())
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (first, second) -> first, LinkedHashMap::new));
            snapshotUtils.storeSnapshotResult(queryId.getId(), doneResults);
        } catch (Exception e) {
            throw new RuntimeException("Failed to store snapshot result for query " + queryId.getId(), e);
        }
    }

    /** Creates a component counter whose completion check is "contains all unfinished tasks". */
    private SnapshotComponentCounter<TaskId> newComponentCounter() {
        return new SnapshotComponentCounter<>(tasks -> tasks.containsAll(unfinishedTasks));
    }

    /**
     * Updates the capture counter for {@code snapshotId}. When the counter reports a change, it
     * stores the query-level result and persists it if the result changed and is done. Snapshots
     * already marked {@code NA} are not updated.
     */
    private void updateQueryCapture(long snapshotId, TaskId taskId, SnapshotComponentCounter.ComponentState componentState) {
        SnapshotComponentCounter<TaskId> counter = captureComponentCounters.computeIfAbsent(snapshotId, id -> newComponentCounter());
        if (!counter.updateComponent(taskId, componentState)) {
            return;
        }
        SnapshotResult result = counter.getSnapshotResult();
        synchronized (captureResults) {
            if (captureResults.get(snapshotId) == SnapshotResult.NA) {
                // Invalidated snapshots are not revived by late reports.
                return;
            }
            LOG.debug("Finished capturing snapshot %d for task %s", snapshotId, taskId);
            SnapshotResult previous = captureResults.put(snapshotId, result);
            if (result != previous && result.isDone()) {
                LOG.debug("Finished capturing snapshot %d for query %s. Result is %s.", snapshotId, queryId.getId(), result);
                saveQuerySnapshotResult();
            }
        }
    }

    /**
     * Updates the restore counter for {@code snapshotId}. When the query-level restore result is
     * updated to a done or in-progress-failed state, it completes the restore attempt.
     */
    private void updateQueryRestore(long snapshotId, TaskId taskId, SnapshotComponentCounter.ComponentState componentState) {
        SnapshotComponentCounter<TaskId> counter = restoreComponentCounters.computeIfAbsent(snapshotId, id -> newComponentCounter());
        if (!counter.updateComponent(taskId, componentState)) {
            return;
        }
        LOG.debug("Finished restoring snapshot %d for task %s", snapshotId, taskId);
        SnapshotResult result = counter.getSnapshotResult();
        boolean updated;
        synchronized (restoreResult) {
            // presumably true when the stored restore result changed — confirm against RestoreResult
            updated = restoreResult.setSnapshotResult(snapshotId, result);
        }
        if (!updated) {
            return;
        }
        if (result.isDone()) {
            LOG.debug("Finished restoring snapshot %d for query %s. Result is %s.", snapshotId, queryId.getId(), result);
            queryRestoreComplete(restoreResult);
        } else if (result == SnapshotResult.IN_PROGRESS_FAILED || result == SnapshotResult.IN_PROGRESS_FAILED_FATAL) {
            LOG.debug("Failed to restore snapshot %d for query %s. Result is %s.", snapshotId, queryId.getId(), result);
            queryRestoreComplete(restoreResult);
        }
    }

    /** Marks the given tasks as finished (see {@link #updateCapturedComponents}). */
    public void updateFinishedQueryComponents(Collection<TaskId> collection) {
        updateCapturedComponents(collection, true);
    }

    /**
     * Removes the given tasks from the unfinished set. If any were removed, it records them as
     * SUCCESSFUL for every snapshot currently being captured.
     *
     * @param tasks tasks that finished or were fully captured
     * @param finished true if the tasks finished, false if they were fully captured (affects logging only)
     */
    public void updateCapturedComponents(Collection<TaskId> tasks, boolean finished) {
        if (!unfinishedTasks.removeAll(tasks)) {
            return;
        }
        if (finished) {
            LOG.debug("Some tasks finished for query %s.%n  Finished tasks: %s.%n  Remaining tasks: %s.%n  Snapshot result: %s", queryId.getId(), tasks, unfinishedTasks, captureResults);
        } else {
            LOG.debug("Some tasks are fully captured for query %s.%n  Captured tasks: %s.%n  Remaining tasks: %s.%n  Snapshot result: %s", queryId.getId(), tasks, unfinishedTasks, captureResults);
        }
        synchronized (captureComponentCounters) {
            for (Long snapshotId : captureComponentCounters.keySet()) {
                for (TaskId task : tasks) {
                    updateQueryCapture(task, ImmutableMap.of(snapshotId, SnapshotResult.SUCCESSFUL));
                }
            }
        }
    }

    /** Returns the live query-level restore result object. */
    public RestoreResult getQuerySnapshotRestoreResult() {
        return restoreResult;
    }

    /** Returns the snapshot utilities this manager was created with. */
    public SnapshotUtils getSnapshotUtils() {
        return snapshotUtils;
    }
}
