package io.prestosql.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.plan.ChildReplacer;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;

/**
 * Fluent helper for locating, removing and replacing {@link PlanNode}s in a plan tree.
 *
 * <p>Traversal is depth-first and pre-order, visiting sources in the order returned by
 * {@link PlanNode#getSources()}. Every visited node is first passed through the configured
 * {@link Lookup}, and predicates are evaluated against the resolved node. Two predicates
 * steer the search:
 * <ul>
 *   <li>{@link #where(Predicate)} selects the nodes an operation acts on (default: every node);</li>
 *   <li>{@link #recurseOnlyWhen(Predicate)} decides whether the traversal descends into a
 *       node's sources (default: always). {@code where} is tested before it, so a matching node
 *       is acted on even when descending below it is not allowed.</li>
 * </ul>
 *
 * <p>The configuration methods mutate and return this instance.
 */
public class PlanNodeSearcher {
    private final PlanNode node;
    private final Lookup lookup;
    private Predicate<PlanNode> where = Predicates.alwaysTrue();
    private Predicate<PlanNode> recurseOnlyWhen = Predicates.alwaysTrue();

    /**
     * Starts a search rooted at {@code planNode} using {@link Lookup#noLookup()}.
     *
     * @param planNode root of the tree to search; must not be null
     * @return a new searcher matching every node until configured otherwise
     */
    public static PlanNodeSearcher searchFrom(PlanNode planNode) {
        return searchFrom(planNode, Lookup.noLookup());
    }

    /**
     * Starts a search rooted at {@code planNode} that resolves every visited node through
     * {@code lookup}.
     *
     * @param planNode root of the tree to search; must not be null
     * @param lookup resolver applied to each visited node; must not be null
     * @return a new searcher matching every node until configured otherwise
     */
    public static PlanNodeSearcher searchFrom(PlanNode planNode, Lookup lookup) {
        return new PlanNodeSearcher(planNode, lookup);
    }

    private PlanNodeSearcher(PlanNode planNode, Lookup lookup) {
        this.node = Objects.requireNonNull(planNode, "node is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
    }

    /**
     * Sets the predicate that selects the nodes every subsequent operation acts on.
     *
     * @param predicate match condition, tested against lookup-resolved nodes; must not be null
     * @return this searcher
     */
    public PlanNodeSearcher where(Predicate<PlanNode> predicate) {
        this.where = Objects.requireNonNull(predicate, "where is null");
        return this;
    }

    /**
     * Sets the predicate that decides whether the traversal descends into a node's sources.
     *
     * @param predicate descent condition, tested against lookup-resolved nodes; must not be null
     * @return this searcher
     */
    public PlanNodeSearcher recurseOnlyWhen(Predicate<PlanNode> predicate) {
        this.recurseOnlyWhen = Objects.requireNonNull(predicate, "recurseOnlyWhen is null");
        return this;
    }

    /**
     * Returns the first matching node in pre-order.
     *
     * <p>The element type {@code T} is chosen by the caller and is not checked. Because generics
     * are erased, a mismatch surfaces as a {@link ClassCastException} at the caller's use site.
     *
     * @return the first lookup-resolved node accepted by {@code where}, or empty if none
     */
    public <T extends PlanNode> Optional<T> findFirst() {
        return findFirstRecursive(node);
    }

    private <T extends PlanNode> Optional<T> findFirstRecursive(PlanNode current) {
        PlanNode resolved = lookup.resolve(current);
        if (where.test(resolved)) {
            return Optional.of(uncheckedCast(resolved));
        }
        if (recurseOnlyWhen.test(resolved)) {
            for (PlanNode source : resolved.getSources()) {
                Optional<T> found = findFirstRecursive(source);
                if (found.isPresent()) {
                    return found;
                }
            }
        }
        return Optional.empty();
    }

    /**
     * Returns the only matching node, if any.
     *
     * @return the single match, or empty if nothing matches
     * @throws IllegalStateException if more than one node matches
     */
    public <T extends PlanNode> Optional<T> findSingle() {
        List<T> all = findAll();
        switch (all.size()) {
            case 0:
                return Optional.empty();
            case 1:
                return Optional.of(all.get(0));
            default:
                throw new IllegalStateException("Multiple nodes found");
        }
    }

    /**
     * Collects every matching node in pre-order.
     *
     * <p>Unlike {@link #findFirst()}, the search continues below a matching node (subject to
     * {@code recurseOnlyWhen}), so nested matches are included.
     *
     * @return an immutable list of lookup-resolved matches; empty if none
     */
    public <T extends PlanNode> List<T> findAll() {
        ImmutableList.Builder<T> builder = ImmutableList.builder();
        findAllRecursive(node, builder);
        return builder.build();
    }

    /**
     * Returns the single matching node.
     *
     * @return the only match
     * @throws java.util.NoSuchElementException if nothing matches
     * @throws IllegalArgumentException if more than one node matches
     */
    public <T extends PlanNode> T findOnlyElement() {
        List<T> all = findAll();
        return Iterables.getOnlyElement(all);
    }

    /**
     * Returns the single matching node, or {@code defaultValue} when nothing matches.
     *
     * @param defaultValue value returned when there is no match
     * @return the only match, or {@code defaultValue}
     * @throws IllegalArgumentException if more than one node matches
     */
    public <T extends PlanNode> T findOnlyElement(T defaultValue) {
        List<T> all = findAll();
        return all.isEmpty() ? defaultValue : Iterables.getOnlyElement(all);
    }

    private <T extends PlanNode> void findAllRecursive(PlanNode current, ImmutableList.Builder<T> builder) {
        PlanNode resolved = lookup.resolve(current);
        if (where.test(resolved)) {
            builder.add(uncheckedCast(resolved));
        }
        if (recurseOnlyWhen.test(resolved)) {
            for (PlanNode source : resolved.getSources()) {
                findAllRecursive(source, builder);
            }
        }
    }

    /**
     * Removes every matching node that is reachable through non-matching ancestors.
     *
     * <p>Each match is replaced by its single source. The search does not continue into that
     * source, so a match sitting directly beneath another match is kept.
     * NOTE(review): confirm that callers rely on this before changing it.
     *
     * @return the rewritten tree, rooted at the resolved root or at its replacement
     * @throws IllegalArgumentException if a matching node does not have exactly one source
     */
    public PlanNode removeAll() {
        return removeAllRecursive(node);
    }

    private PlanNode removeAllRecursive(PlanNode current) {
        PlanNode resolved = lookup.resolve(current);
        if (where.test(resolved)) {
            Preconditions.checkArgument(resolved.getSources().size() == 1, "Unable to remove plan node as it contains 0 or more than 1 children");
            return resolved.getSources().get(0);
        }
        if (!recurseOnlyWhen.test(resolved)) {
            return resolved;
        }
        List<PlanNode> newSources = resolved.getSources().stream()
                .map(this::removeAllRecursive)
                .collect(ImmutableList.toImmutableList());
        return ChildReplacer.replaceChildren(resolved, newSources);
    }

    /**
     * Removes the first matching node along a single-source chain from the root.
     *
     * @return the rewritten tree; unchanged if the chain ends without a match
     * @throws IllegalArgumentException if the matching node does not have exactly one source, or
     *     if the traversal reaches a non-matching node with more than one source
     */
    public PlanNode removeFirst() {
        return removeFirstRecursive(node);
    }

    private PlanNode removeFirstRecursive(PlanNode current) {
        PlanNode resolved = lookup.resolve(current);
        if (where.test(resolved)) {
            Preconditions.checkArgument(resolved.getSources().size() == 1, "Unable to remove plan node as it contains 0 or more than 1 children");
            return resolved.getSources().get(0);
        }
        if (!recurseOnlyWhen.test(resolved)) {
            return resolved;
        }
        List<PlanNode> sources = resolved.getSources();
        if (sources.isEmpty()) {
            return resolved;
        }
        if (sources.size() == 1) {
            return ChildReplacer.replaceChildren(resolved, ImmutableList.of(removeFirstRecursive(sources.get(0))));
        }
        throw new IllegalArgumentException("Unable to remove first node when a node has multiple children, use removeAll instead");
    }

    /**
     * Replaces every matching node (reachable through non-matching ancestors) with
     * {@code replacement}.
     *
     * <p>The same instance is used for every match. Subtrees of replaced nodes are not searched.
     *
     * @param replacement node substituted for each match
     * @return the rewritten tree
     */
    public PlanNode replaceAll(PlanNode replacement) {
        return replaceAllRecursive(node, replacement);
    }

    private PlanNode replaceAllRecursive(PlanNode current, PlanNode replacement) {
        PlanNode resolved = lookup.resolve(current);
        if (where.test(resolved)) {
            return replacement;
        }
        if (!recurseOnlyWhen.test(resolved)) {
            return resolved;
        }
        List<PlanNode> newSources = resolved.getSources().stream()
                .map(source -> replaceAllRecursive(source, replacement))
                .collect(ImmutableList.toImmutableList());
        return ChildReplacer.replaceChildren(resolved, newSources);
    }

    /**
     * Replaces the first matching node along a single-source chain from the root with
     * {@code replacement}.
     *
     * @param replacement node substituted for the first match
     * @return the rewritten tree; unchanged if the chain ends (or descent is refused) without a match
     * @throws IllegalArgumentException if the traversal reaches a non-matching node with more than
     *     one source
     */
    public PlanNode replaceFirst(PlanNode replacement) {
        return replaceFirstRecursive(node, replacement);
    }

    private PlanNode replaceFirstRecursive(PlanNode current, PlanNode replacement) {
        PlanNode resolved = lookup.resolve(current);
        if (where.test(resolved)) {
            return replacement;
        }
        // Honor recurseOnlyWhen like every other traversal in this class.
        if (!recurseOnlyWhen.test(resolved)) {
            return resolved;
        }
        List<PlanNode> sources = resolved.getSources();
        if (sources.isEmpty()) {
            return resolved;
        }
        if (sources.size() == 1) {
            // Descend into the single source, carrying the replacement along. The previous code
            // passed (node, source) here, which recursed on the same node forever.
            return ChildReplacer.replaceChildren(resolved, ImmutableList.of(replaceFirstRecursive(sources.get(0), replacement)));
        }
        throw new IllegalArgumentException("Unable to replace first node when a node has multiple children, use replaceAll instead");
    }

    /**
     * Reports whether any node matches.
     *
     * @return {@code true} if {@link #findFirst()} finds a node
     */
    public boolean matches() {
        return findFirst().isPresent();
    }

    /**
     * Counts the matching nodes.
     *
     * @return the number of nodes {@link #findAll()} returns
     */
    public int count() {
        return findAll().size();
    }

    /**
     * Narrows a plan node to the caller-chosen subtype {@code T}.
     *
     * <p>The cast is unchecked because of erasure. Correctness relies on the {@code where}
     * predicate only accepting nodes of the type the caller requests.
     */
    @SuppressWarnings("unchecked")
    private static <T extends PlanNode> T uncheckedCast(PlanNode planNode) {
        return (T) planNode;
    }
}
