package io.prestosql.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import io.airlift.log.Logger;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.cost.CostCalculator;
import io.prestosql.cost.CostComparator;
import io.prestosql.cost.CostProvider;
import io.prestosql.cost.PlanCostEstimate;
import io.prestosql.cost.StatsCalculator;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.PlanNodeIdAllocator;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.TableScanNode;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.analyzer.FeaturesConfig;
import io.prestosql.sql.planner.EqualityInference;
import io.prestosql.sql.planner.ExpressionDeterminismEvaluator;
import io.prestosql.sql.planner.PlanSymbolAllocator;
import io.prestosql.sql.planner.RuleStatsRecorder;
import io.prestosql.sql.planner.SymbolUtils;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.IterativeOptimizer;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.iterative.rule.ReorderJoins;
import io.prestosql.sql.planner.optimizations.JoinNodeUtils;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.optimizations.QueryCardinalityUtil;
import io.prestosql.sql.planner.plan.ChildReplacer;
import io.prestosql.sql.planner.plan.InternalPlanVisitor;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/HintedReorderJoins.class */
public class HintedReorderJoins implements PlanOptimizer {
    // Collaborators of this optimizer. None of them is read by the rule classes declared below;
    // presumably they are assigned by a constructor and consumed when the optimizer builds and runs
    // its iterative rule(s) — TODO confirm against the rest of the file.
    private final RuleStatsRecorder stats; // assumed: per-rule statistics sink for the iterative optimizer
    private final StatsCalculator statsCalculator; // assumed: stats source for cost estimation
    private final CostCalculator costCalculator; // assumed: cost source for candidate plans
    private final CostComparator costComparator; // assumed: forwarded to HintedReorderJoinsRule

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/HintedReorderJoins$HintedReorderJoinsRule.class */
    public static class HintedReorderJoinsRule implements Rule<JoinNode> {
        private static final Logger log = Logger.get(HintedReorderJoins.class);

        // Candidates are INNER joins whose distribution has not been chosen yet and whose filter,
        // when present, is deterministic. A missing filter is treated as TRUE.
        private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(node -> {
            if (node.getDistributionType().isPresent()) {
                return false;
            }
            if (node.getType() != JoinNode.Type.INNER) {
                return false;
            }
            Expression joinFilter = node.getFilter()
                    .map(OriginalExpressionUtils::castToExpression)
                    .orElse(BooleanLiteral.TRUE_LITERAL);
            return ExpressionDeterminismEvaluator.isDeterministic(joinFilter);
        });

        // Cost comparator supplied through the constructor.
        private final CostComparator costComparator;

        /**
         * A costed candidate join tree.
         *
         * <p>A plan node is present exactly when the cost is fully known and finite. Unknown and infinite
         * costs are represented by the shared {@link #UNKNOWN_COST_RESULT} and {@link #INFINITE_COST_RESULT}
         * instances, which carry no plan. {@code equals} is not overridden, so these sentinels are
         * recognised by identity.
         */
        @VisibleForTesting
        public static class JoinEnumerationResult {
            public static final JoinEnumerationResult UNKNOWN_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.unknown());
            public static final JoinEnumerationResult INFINITE_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.infinite());

            private final Optional<PlanNode> planNode;
            private final PlanCostEstimate cost;

            private JoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost) {
                this.planNode = Objects.requireNonNull(planNode, "planNode is null");
                this.cost = Objects.requireNonNull(cost, "cost is null");
                // "Known" means neither unknown nor infinite: a plan may accompany neither kind of cost.
                boolean costKnown = !cost.hasUnknownComponents() && !cost.equals(PlanCostEstimate.infinite());
                Preconditions.checkArgument(costKnown == planNode.isPresent(), "planNode should be present if and only if cost is known");
            }

            /** Returns the costed plan, or empty for the unknown/infinite sentinels. */
            public Optional<PlanNode> getPlanNode() {
                return this.planNode;
            }

            /** Returns the estimated cost of the plan. */
            public PlanCostEstimate getCost() {
                return this.cost;
            }

            /**
             * Wraps a plan and its cost. Unknown costs map to {@link #UNKNOWN_COST_RESULT} and infinite costs
             * to {@link #INFINITE_COST_RESULT}; in both cases the supplied plan is dropped.
             *
             * @throws IllegalArgumentException if the cost is known and finite but {@code planNode} is empty
             */
            static JoinEnumerationResult createJoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost) {
                if (cost.hasUnknownComponents()) {
                    return UNKNOWN_COST_RESULT;
                }
                if (cost.equals(PlanCostEstimate.infinite())) {
                    return INFINITE_COST_RESULT;
                }
                return new JoinEnumerationResult(planNode, cost);
            }
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        @VisibleForTesting
        /* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/HintedReorderJoins$HintedReorderJoinsRule$JoinEnumerator.class */
        public static class JoinEnumerator {
            private final Session session;
            private final CostProvider costProvider;
            private final Ordering<JoinEnumerationResult> resultComparator;
            private final PlanNodeIdAllocator idAllocator;
            private final Expression allFilter;
            private final EqualityInference allFilterInference;
            private final Lookup lookup;
            private final Rule.Context context;
            private final Map<Set<PlanNode>, JoinEnumerationResult> memo = new HashMap();

            @VisibleForTesting
            JoinEnumerator(CostComparator costComparator, Expression expression, Rule.Context context) {
                this.context = (Rule.Context) Objects.requireNonNull(context);
                this.session = (Session) Objects.requireNonNull(context.getSession(), "session is null");
                this.costProvider = (CostProvider) Objects.requireNonNull(context.getCostProvider(), "costProvider is null");
                this.resultComparator = costComparator.forSession(this.session).onResultOf(joinEnumerationResult -> {
                    return joinEnumerationResult.cost;
                });
                this.idAllocator = (PlanNodeIdAllocator) Objects.requireNonNull(context.getIdAllocator(), "idAllocator is null");
                this.allFilter = (Expression) Objects.requireNonNull(expression, "filter is null");
                this.allFilterInference = EqualityInference.createEqualityInference(expression);
                this.lookup = (Lookup) Objects.requireNonNull(context.getLookup(), "lookup is null");
            }

            /* JADX INFO: Access modifiers changed from: private */
            public JoinEnumerationResult chooseJoinOrder(LinkedHashSet<PlanNode> linkedHashSet, List<Symbol> list, String str) {
                this.context.checkTimeoutNotExhausted();
                Set<PlanNode> copyOf = ImmutableSet.copyOf(linkedHashSet);
                JoinEnumerationResult joinEnumerationResult = this.memo.get(copyOf);
                if (joinEnumerationResult == null) {
                    Preconditions.checkState(linkedHashSet.size() > 1, "sources size is less than or equal to one");
                    ImmutableList.Builder builder = ImmutableList.builder();
                    Iterator<Set<Integer>> it = generatePartitions(linkedHashSet.size()).iterator();
                    while (it.hasNext()) {
                        JoinEnumerationResult createJoinAccordingToPartitioning = createJoinAccordingToPartitioning(linkedHashSet, list, it.next(), str);
                        if (createJoinAccordingToPartitioning.planNode.isPresent() && str != null && !"".equals(str)) {
                            StringBuilder sb = new StringBuilder();
                            ((PlanNode) createJoinAccordingToPartitioning.planNode.get()).accept(new TableNameExtractor(this.lookup), sb);
                            if (TableNameExtractor.startsWith(sb.toString(), str)) {
                                this.memo.put(copyOf, createJoinAccordingToPartitioning);
                                return createJoinAccordingToPartitioning;
                            }
                        }
                        if (createJoinAccordingToPartitioning.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                            this.memo.put(copyOf, createJoinAccordingToPartitioning);
                            return createJoinAccordingToPartitioning;
                        }
                        if (!createJoinAccordingToPartitioning.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                            builder.add(createJoinAccordingToPartitioning);
                        }
                    }
                    ImmutableList<JoinEnumerationResult> build = builder.build();
                    if (build.isEmpty()) {
                        this.memo.put(copyOf, JoinEnumerationResult.INFINITE_COST_RESULT);
                        return JoinEnumerationResult.INFINITE_COST_RESULT;
                    }
                    if (build.size() > 1) {
                        for (JoinEnumerationResult joinEnumerationResult2 : build) {
                            StringBuilder sb2 = new StringBuilder();
                            ((PlanNode) joinEnumerationResult2.planNode.get()).accept(new TableNameExtractor(this.lookup), sb2);
                            if (TableNameExtractor.startsWith(sb2.toString(), str)) {
                                this.memo.put(copyOf, joinEnumerationResult2);
                                return joinEnumerationResult2;
                            }
                        }
                    }
                    joinEnumerationResult = (JoinEnumerationResult) this.resultComparator.min(build);
                    this.memo.put(copyOf, joinEnumerationResult);
                }
                joinEnumerationResult.planNode.ifPresent(planNode -> {
                    HintedReorderJoinsRule.log.debug("Least cost join was: %s", new Object[]{planNode});
                });
                return joinEnumerationResult;
            }

            @VisibleForTesting
            static Set<Set<Integer>> generatePartitions(int i) {
                Preconditions.checkArgument(i > 1, "totalNodes must be greater than 1");
                Set set = (Set) IntStream.range(0, i).boxed().collect(ImmutableSet.toImmutableSet());
                return (Set) Sets.powerSet(set).stream().filter(set2 -> {
                    return set2.contains(0);
                }).filter(set3 -> {
                    return set3.size() < set.size();
                }).collect(ImmutableSet.toImmutableSet());
            }

            @VisibleForTesting
            JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet<PlanNode> linkedHashSet, List<Symbol> list, Set<Integer> set) {
                return createJoinAccordingToPartitioning(linkedHashSet, list, set, "");
            }

            @VisibleForTesting
            JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet<PlanNode> linkedHashSet, List<Symbol> list, Set<Integer> set, String str) {
                ImmutableList copyOf = ImmutableList.copyOf(linkedHashSet);
                Stream<Integer> stream = set.stream();
                copyOf.getClass();
                LinkedHashSet<PlanNode> linkedHashSet2 = (LinkedHashSet) stream.map((v1) -> {
                    return r1.get(v1);
                }).collect(Collectors.toCollection(LinkedHashSet::new));
                return createJoin(linkedHashSet2, (LinkedHashSet) linkedHashSet.stream().filter(planNode -> {
                    return !linkedHashSet2.contains(planNode);
                }).collect(Collectors.toCollection(LinkedHashSet::new)), list, str);
            }

            private JoinEnumerationResult createJoin(LinkedHashSet<PlanNode> linkedHashSet, LinkedHashSet<PlanNode> linkedHashSet2, List<Symbol> list, String str) {
                Set<Symbol> set = (Set) linkedHashSet.stream().flatMap(planNode -> {
                    return planNode.getOutputSymbols().stream();
                }).collect(ImmutableSet.toImmutableSet());
                Set<Symbol> set2 = (Set) linkedHashSet2.stream().flatMap(planNode2 -> {
                    return planNode2.getOutputSymbols().stream();
                }).collect(ImmutableSet.toImmutableSet());
                List<Expression> joinPredicates = getJoinPredicates(set, set2);
                List list2 = (List) joinPredicates.stream().filter(JoinEnumerator::isJoinEqualityCondition).map(expression -> {
                    return toEquiJoinClause((ComparisonExpression) expression, set);
                }).collect(ImmutableList.toImmutableList());
                if (list2.isEmpty()) {
                    return JoinEnumerationResult.INFINITE_COST_RESULT;
                }
                List list3 = (List) joinPredicates.stream().filter(expression2 -> {
                    return !isJoinEqualityCondition(expression2);
                }).collect(ImmutableList.toImmutableList());
                ImmutableSet build = ImmutableSet.builder().addAll(list).addAll(SymbolsExtractor.extractUnique(joinPredicates)).build();
                Stream stream = build.stream();
                set.getClass();
                JoinEnumerationResult joinSource = getJoinSource(linkedHashSet, (List) stream.filter((v1) -> {
                    return r3.contains(v1);
                }).collect(ImmutableList.toImmutableList()), str);
                if (joinSource.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                    return JoinEnumerationResult.UNKNOWN_COST_RESULT;
                }
                if (joinSource.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                    return JoinEnumerationResult.INFINITE_COST_RESULT;
                }
                PlanNode planNode3 = (PlanNode) joinSource.planNode.orElseThrow(() -> {
                    return new VerifyException("Plan node is not present");
                });
                Stream stream2 = build.stream();
                set2.getClass();
                JoinEnumerationResult joinSource2 = getJoinSource(linkedHashSet2, (List) stream2.filter((v1) -> {
                    return r3.contains(v1);
                }).collect(ImmutableList.toImmutableList()), str);
                if (joinSource2.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                    return JoinEnumerationResult.UNKNOWN_COST_RESULT;
                }
                if (joinSource2.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                    return JoinEnumerationResult.INFINITE_COST_RESULT;
                }
                PlanNode planNode4 = (PlanNode) joinSource2.planNode.orElseThrow(() -> {
                    return new VerifyException("Plan node is not present");
                });
                Stream concat = Stream.concat(planNode3.getOutputSymbols().stream(), planNode4.getOutputSymbols().stream());
                list.getClass();
                return setJoinNodeProperties(new JoinNode(this.idAllocator.getNextId(), JoinNode.Type.INNER, planNode3, planNode4, list2, (List) concat.filter((v1) -> {
                    return r1.contains(v1);
                }).collect(ImmutableList.toImmutableList()), list3.isEmpty() ? Optional.empty() : Optional.of(ExpressionUtils.and(list3)).map(OriginalExpressionUtils::castToRowExpression), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of()), str);
            }

            private List<Expression> getJoinPredicates(Set<Symbol> set, Set<Symbol> set2) {
                ImmutableList.Builder builder = ImmutableList.builder();
                Stream filter = Streams.stream(EqualityInference.nonInferrableConjuncts(this.allFilter)).map(expression -> {
                    return this.allFilterInference.rewriteExpression(expression, symbol -> {
                        return set.contains(symbol) || set2.contains(symbol);
                    });
                }).filter((v0) -> {
                    return Objects.nonNull(v0);
                }).filter(expression2 -> {
                    EqualityInference equalityInference = this.allFilterInference;
                    set.getClass();
                    return equalityInference.rewriteExpression(expression2, (v1) -> {
                        return r2.contains(v1);
                    }) == null;
                }).filter(expression3 -> {
                    EqualityInference equalityInference = this.allFilterInference;
                    set2.getClass();
                    return equalityInference.rewriteExpression(expression3, (v1) -> {
                        return r2.contains(v1);
                    }) == null;
                });
                builder.getClass();
                filter.forEach((v1) -> {
                    r1.add(v1);
                });
                builder.addAll(EqualityInference.createEqualityInference((Expression[]) this.allFilterInference.generateEqualitiesPartitionedBy(symbol -> {
                    return set.contains(symbol) || set2.contains(symbol);
                }).getScopeEqualities().toArray(new Expression[0])).generateEqualitiesPartitionedBy(Predicates.in(set)).getScopeStraddlingEqualities());
                return builder.build();
            }

            private JoinEnumerationResult getJoinSource(LinkedHashSet<PlanNode> linkedHashSet, List<Symbol> list, String str) {
                if (linkedHashSet.size() != 1) {
                    return chooseJoinOrder(linkedHashSet, list, str);
                }
                PlanNode planNode = (PlanNode) Iterables.getOnlyElement(linkedHashSet);
                ImmutableList.Builder builder = ImmutableList.builder();
                EqualityInference equalityInference = this.allFilterInference;
                list.getClass();
                builder.addAll(equalityInference.generateEqualitiesPartitionedBy((v1) -> {
                    return r2.contains(v1);
                }).getScopeEqualities());
                Stream filter = Streams.stream(EqualityInference.nonInferrableConjuncts(this.allFilter)).map(expression -> {
                    EqualityInference equalityInference2 = this.allFilterInference;
                    list.getClass();
                    return equalityInference2.rewriteExpression(expression, (v1) -> {
                        return r2.contains(v1);
                    });
                }).filter((v0) -> {
                    return Objects.nonNull(v0);
                });
                builder.getClass();
                filter.forEach((v1) -> {
                    r1.add(v1);
                });
                Expression combineConjuncts = ExpressionUtils.combineConjuncts((Collection<Expression>) builder.build());
                if (!BooleanLiteral.TRUE_LITERAL.equals(combineConjuncts)) {
                    planNode = new FilterNode(this.idAllocator.getNextId(), planNode, OriginalExpressionUtils.castToRowExpression(combineConjuncts));
                }
                return createJoinEnumerationResult(planNode);
            }

            private static boolean isJoinEqualityCondition(Expression expression) {
                return (expression instanceof ComparisonExpression) && ((ComparisonExpression) expression).getOperator() == ComparisonExpression.Operator.EQUAL && (((ComparisonExpression) expression).getLeft() instanceof SymbolReference) && (((ComparisonExpression) expression).getRight() instanceof SymbolReference);
            }

            /* JADX INFO: Access modifiers changed from: private */
            public static JoinNode.EquiJoinClause toEquiJoinClause(ComparisonExpression comparisonExpression, Set<Symbol> set) {
                Symbol from = SymbolUtils.from(comparisonExpression.getLeft());
                JoinNode.EquiJoinClause equiJoinClause = new JoinNode.EquiJoinClause(from, SymbolUtils.from(comparisonExpression.getRight()));
                return set.contains(from) ? equiJoinClause : equiJoinClause.flip();
            }

            private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode, String str) {
                if (QueryCardinalityUtil.isAtMostScalar(joinNode.getRight(), this.lookup)) {
                    return createJoinEnumerationResult(joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED));
                }
                if (QueryCardinalityUtil.isAtMostScalar(joinNode.getLeft(), this.lookup)) {
                    return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(JoinNode.DistributionType.REPLICATED));
                }
                List<JoinEnumerationResult> possibleJoinNodes = getPossibleJoinNodes(joinNode, SystemSessionProperties.getJoinDistributionType(this.session));
                Verify.verify(!possibleJoinNodes.isEmpty(), "possibleJoinNodes is empty", new Object[0]);
                Stream<JoinEnumerationResult> stream = possibleJoinNodes.stream();
                JoinEnumerationResult joinEnumerationResult = JoinEnumerationResult.UNKNOWN_COST_RESULT;
                joinEnumerationResult.getClass();
                if (stream.anyMatch((v1) -> {
                    return r1.equals(v1);
                })) {
                    return JoinEnumerationResult.UNKNOWN_COST_RESULT;
                }
                for (JoinEnumerationResult joinEnumerationResult2 : possibleJoinNodes) {
                    StringBuilder sb = new StringBuilder();
                    ((PlanNode) joinEnumerationResult2.planNode.get()).accept(new TableNameExtractor(this.lookup), sb);
                    if (TableNameExtractor.startsWith(sb.toString(), str)) {
                        return joinEnumerationResult2;
                    }
                }
                return (JoinEnumerationResult) this.resultComparator.min(possibleJoinNodes);
            }

            private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, FeaturesConfig.JoinDistributionType joinDistributionType) {
                Preconditions.checkArgument(joinNode.getType() == JoinNode.Type.INNER, "unexpected join node type: %s", joinNode.getType());
                if (joinNode.isCrossJoin()) {
                    return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.REPLICATED);
                }
                switch (joinDistributionType) {
                    case PARTITIONED:
                        return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.PARTITIONED);
                    case BROADCAST:
                        return getPossibleJoinNodes(joinNode, JoinNode.DistributionType.REPLICATED);
                    case AUTOMATIC:
                        ImmutableList.Builder builder = ImmutableList.builder();
                        builder.addAll(getPossibleJoinNodes(joinNode, JoinNode.DistributionType.PARTITIONED));
                        if (DetermineJoinDistributionType.canReplicate(joinNode, this.context)) {
                            builder.addAll(getPossibleJoinNodes(joinNode, JoinNode.DistributionType.REPLICATED));
                        }
                        return builder.build();
                    default:
                        throw new IllegalArgumentException("unexpected join distribution type: " + joinDistributionType);
                }
            }

            private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, JoinNode.DistributionType distributionType) {
                return ImmutableList.of(createJoinEnumerationResult(joinNode.withDistributionType(distributionType)), createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(distributionType)));
            }

            private JoinEnumerationResult createJoinEnumerationResult(PlanNode planNode) {
                return JoinEnumerationResult.createJoinEnumerationResult(Optional.of(planNode), this.costProvider.getCost(planNode));
            }
        }

        /**
         * A flattened tree of INNER joins. It holds the leaf sources in encounter order, a single filter
         * expression, and the output symbols.
         */
        @VisibleForTesting
        public static class MultiJoinNode {
            private final LinkedHashSet<PlanNode> sources;
            private final Expression filter;
            private final List<Symbol> outputSymbols;

            /** Step-by-step builder for {@link MultiJoinNode}. */
            static class Builder {
                private List<PlanNode> sources;
                private Expression filter;
                private List<Symbol> outputSymbols;

                Builder() {
                }

                public Builder setSources(PlanNode... sources) {
                    this.sources = ImmutableList.copyOf(sources);
                    return this;
                }

                public Builder setFilter(Expression filter) {
                    this.filter = filter;
                    return this;
                }

                public Builder setOutputSymbols(Symbol... outputSymbols) {
                    this.outputSymbols = ImmutableList.copyOf(outputSymbols);
                    return this;
                }

                public MultiJoinNode build() {
                    return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputSymbols);
                }
            }

            /**
             * Walks a tree of joins and collects their leaf sources and join predicates. The walk descends
             * only into INNER joins that have a deterministic (or absent) filter and no distribution type,
             * and stops once the source limit would be exceeded.
             */
            /* JADX INFO: Access modifiers changed from: private */
            public static class JoinNodeFlattener {
                private final LinkedHashSet<PlanNode> sources = new LinkedHashSet<>();
                private final List<Expression> filters = new ArrayList<>();
                private final List<Symbol> outputSymbols;
                private final Lookup lookup;

                JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit) {
                    Objects.requireNonNull(node, "node is null");
                    Preconditions.checkState(node.getType() == JoinNode.Type.INNER, "join type must be INNER");
                    this.outputSymbols = node.getOutputSymbols();
                    this.lookup = Objects.requireNonNull(lookup, "lookup is null");
                    flattenNode(node, sourceLimit);
                }

                private void flattenNode(PlanNode node, int limit) {
                    PlanNode resolved = lookup.resolve(node);
                    // (limit - 2): leave room for the left and right sides this join would contribute.
                    if (!(resolved instanceof JoinNode) || sources.size() > limit - 2) {
                        sources.add(node);
                        return;
                    }
                    JoinNode joinNode = (JoinNode) resolved;
                    // The filter is cast only after the type check succeeds, preserving short-circuit order.
                    if (joinNode.getType() != JoinNode.Type.INNER
                            || !ExpressionDeterminismEvaluator.isDeterministic(joinNode.getFilter().map(OriginalExpressionUtils::castToExpression).orElse(BooleanLiteral.TRUE_LITERAL))
                            || joinNode.getDistributionType().isPresent()) {
                        sources.add(node);
                        return;
                    }
                    // The left side gets limit - 1 to account for the node(s) on the right.
                    flattenNode(joinNode.getLeft(), limit - 1);
                    flattenNode(joinNode.getRight(), limit);
                    joinNode.getCriteria().stream()
                            .map(JoinNodeUtils::toExpression)
                            .forEach(filters::add);
                    joinNode.getFilter()
                            .map(OriginalExpressionUtils::castToExpression)
                            .ifPresent(filters::add);
                }

                MultiJoinNode toMultiJoinNode() {
                    return new MultiJoinNode(sources, ExpressionUtils.and(filters), outputSymbols);
                }
            }

            /**
             * @param sources leaf sources, at least two; stored as given (not copied)
             * @param filter predicate over the sources
             * @param outputSymbols symbols to expose; each must be produced by some source
             * @throws NullPointerException if any argument is null
             * @throws IllegalArgumentException if fewer than two sources are given or an output symbol is
             *     not produced by any source
             */
            public MultiJoinNode(LinkedHashSet<PlanNode> sources, Expression filter, List<Symbol> outputSymbols) {
                Objects.requireNonNull(sources, "sources is null");
                Preconditions.checkArgument(sources.size() > 1, "sources size is <= 1");
                Objects.requireNonNull(filter, "filter is null");
                Objects.requireNonNull(outputSymbols, "outputSymbols is null");
                this.sources = sources;
                this.filter = filter;
                this.outputSymbols = ImmutableList.copyOf(outputSymbols);
                // A set makes containsAll linear in the number of output symbols.
                Set<Symbol> inputSymbols = sources.stream()
                        .flatMap(source -> source.getOutputSymbols().stream())
                        .collect(ImmutableSet.toImmutableSet());
                Preconditions.checkArgument(inputSymbols.containsAll(outputSymbols), "inputs do not contain all output symbols");
            }

            public Expression getFilter() {
                return this.filter;
            }

            public LinkedHashSet<PlanNode> getSources() {
                return this.sources;
            }

            public List<Symbol> getOutputSymbols() {
                return this.outputSymbols;
            }

            public static Builder builder() {
                return new Builder();
            }

            /** Hashes the filter by its set of conjuncts, consistently with {@link #equals(Object)}. */
            @Override
            public int hashCode() {
                return Objects.hash(this.sources, ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(this.filter)), this.outputSymbols);
            }

            /** Two nodes are equal when sources, output symbols and the set of filter conjuncts are equal. */
            @Override
            public boolean equals(Object obj) {
                if (!(obj instanceof MultiJoinNode)) {
                    return false;
                }
                MultiJoinNode other = (MultiJoinNode) obj;
                return this.sources.equals(other.sources)
                        && ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(this.filter)).equals(ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(other.filter)))
                        && this.outputSymbols.equals(other.outputSymbols);
            }

            /**
             * Flattens {@code joinNode}. The flattener's source limit is {@code joinLimit + 1}, presumably so
             * that the limit counts joins rather than sources (TODO confirm against the caller).
             */
            static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit) {
                return new JoinNodeFlattener(joinNode, lookup, joinLimit + 1).toMultiJoinNode();
            }
        }

        /**
         * Renders the table layout of a plan into a {@link StringBuilder}. Table scans append their table
         * name, joins append {@code (left,right)}, and nodes handled by {@code visitPlan} append the
         * renderings of their sources in order. Group references are resolved through the lookup.
         */
        /* JADX INFO: Access modifiers changed from: private */
        public static class TableNameExtractor extends InternalPlanVisitor<Void, StringBuilder> {
            private final Lookup lookup;

            private TableNameExtractor(Lookup lookup) {
                this.lookup = lookup;
            }

            public Void visitTableScan(TableScanNode tableScanNode, StringBuilder output) {
                output.append(tableScanNode.getTable().getConnectorHandle().getTableName().toString());
                return null;
            }

            public Void visitPlan(PlanNode planNode, StringBuilder output) {
                for (PlanNode source : this.lookup.resolve(planNode).getSources()) {
                    this.lookup.resolve(source).accept(this, output);
                }
                return null;
            }

            public Void visitJoin(JoinNode joinNode, StringBuilder output) {
                PlanNode left = this.lookup.resolve(joinNode.getLeft());
                PlanNode right = this.lookup.resolve(joinNode.getRight());
                output.append('(');
                left.accept(this, output);
                output.append(',');
                right.accept(this, output);
                output.append(')');
                return null;
            }

            /**
             * Returns whether {@code joinOrderHint} contains {@code tableLayout} as a substring. Despite its
             * name, this is a containment test, not a prefix test.
             *
             * @return false when {@code joinOrderHint} is null
             */
            public static boolean startsWith(String tableLayout, String joinOrderHint) {
                return joinOrderHint != null && joinOrderHint.contains(tableLayout);
            }
        }

        /**
         * Creates the rule.
         *
         * @param costComparator comparator handed to the join enumerator; must not be null
         * @throws NullPointerException if {@code costComparator} is null
         */
        public HintedReorderJoinsRule(CostComparator costComparator) {
            Objects.requireNonNull(costComparator, "costComparator is null");
            this.costComparator = costComparator;
        }

        /**
         * Returns the {@link JoinNode} pattern this rule matches.
         */
        @Override // io.prestosql.sql.planner.iterative.Rule
        public Pattern<JoinNode> getPattern() {
            final Pattern<JoinNode> joinPattern = PATTERN;
            return joinPattern;
        }

        /**
         * The hinted rule is active only when cost-based join reordering is switched off
         * ({@code JoinReorderingStrategy.NONE}) and a non-blank join-order hint is set in the session.
         *
         * <p>The blank check trims the hint, matching how {@code apply} reads it. Previously a
         * whitespace-only hint enabled the rule, yet {@code apply} then passed an empty hint along.
         */
        @Override // io.prestosql.sql.planner.iterative.Rule
        public boolean isEnabled(Session session) {
            if (SystemSessionProperties.getJoinReorderingStrategy(session) != FeaturesConfig.JoinReorderingStrategy.NONE) {
                return false;
            }
            String joinOrder = SystemSessionProperties.getJoinOrder(session);
            return joinOrder != null && !joinOrder.trim().isEmpty();
        }

        /**
         * Flattens the matched join tree and asks the {@link JoinEnumerator} for a join order guided by
         * the session's join-order hint.
         *
         * <p>A former loop ran {@link TableNameExtractor} over every source into a throw-away
         * {@link StringBuilder}. Its output was never used, so it has been removed.
         *
         * @return the reordered plan, or {@link Rule.Result#empty()} when the enumerator produced none
         */
        @Override // io.prestosql.sql.planner.iterative.Rule
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            MultiJoinNode multiJoinNode = MultiJoinNode.toMultiJoinNode(joinNode, context.getLookup(), SystemSessionProperties.getMaxReorderedJoins(context.getSession()));
            JoinEnumerator joinEnumerator = new JoinEnumerator(this.costComparator, multiJoinNode.getFilter(), context);

            // Normalize the hint: absent or blank becomes "", otherwise the trimmed value.
            String joinOrder = SystemSessionProperties.getJoinOrder(context.getSession());
            String hint = joinOrder == null ? "" : joinOrder.trim();

            JoinEnumerationResult chooseJoinOrder = joinEnumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols(), hint);
            if (!chooseJoinOrder.getPlanNode().isPresent()) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(chooseJoinOrder.getPlanNode().get());
        }
    }

    /**
     * Flattens a tree of consecutive INNER joins (without lookup resolution) into a
     * {@link ReorderJoins.MultiJoinNode}.
     *
     * <p>A join node is descended into only when it is INNER, its filter is deterministic, it has no
     * fixed distribution type, and the source limit has not been reached. Any other node becomes a leaf
     * source. Equi-join criteria of flattened joins are collected into the filter list.
     */
    public static class JoinNodeCandidatesExtractor {
        private final LinkedHashSet<PlanNode> sources = new LinkedHashSet<>();
        private final List<Expression> filters = new ArrayList<>();
        private final List<Symbol> outputSymbols;

        JoinNodeCandidatesExtractor(JoinNode joinNode, int i) {
            Objects.requireNonNull(joinNode, "node is null");
            Preconditions.checkState(joinNode.getType() == JoinNode.Type.INNER, "join type must be INNER");
            this.outputSymbols = joinNode.getOutputSymbols();
            flattenNode(joinNode, i);
        }

        /**
         * Flattens {@code joinNode}. The extractor receives {@code i + 1} as its limit.
         *
         * @param joinNode root INNER join
         * @param i maximum number of joins to reorder (session {@code max_reordered_joins})
         * @return multi-join node whose sources are the flattened leaves
         */
        static ReorderJoins.MultiJoinNode toMultiJoinNode(JoinNode joinNode, int i) {
            return new JoinNodeCandidatesExtractor(joinNode, i + 1).toMultiJoinNode();
        }

        private void flattenNode(PlanNode planNode, int i) {
            // (i - 2) leaves room for the left and right side that a further join would add.
            if (!(planNode instanceof JoinNode) || this.sources.size() > i - 2) {
                this.sources.add(planNode);
                return;
            }
            JoinNode joinNode = (JoinNode) planNode;
            if (joinNode.getType() != JoinNode.Type.INNER || !ExpressionDeterminismEvaluator.isDeterministic((Expression) joinNode.getFilter().map(OriginalExpressionUtils::castToExpression).orElse(BooleanLiteral.TRUE_LITERAL)) || joinNode.getDistributionType().isPresent()) {
                this.sources.add(planNode);
                return;
            }
            // The left side gets (i - 1) to account for the node on the right.
            flattenNode(joinNode.getLeft(), i - 1);
            flattenNode(joinNode.getRight(), i);
            // Fix: the decompiled form referenced an undefined variable (r1) and did not compile.
            joinNode.getCriteria().stream()
                    .map(JoinNodeUtils::toExpression)
                    .forEach(this.filters::add);
        }

        ReorderJoins.MultiJoinNode toMultiJoinNode() {
            return new ReorderJoins.MultiJoinNode(this.sources, ExpressionUtils.and(this.filters), this.outputSymbols);
        }
    }

    /**
     * Rewriter that drives {@link HintedReorderJoinsRule} through an {@link IterativeOptimizer}.
     *
     * <p>On the first {@link JoinNode} it reaches, it records that join tree's flattened leaves as
     * "optimizable sources". It then rewrites both children and optimizes the rebuilt join. Any later
     * node found in that recorded set is handed straight to the iterative optimizer.
     */
    private class PreRuleOptimizer extends SimplePlanRewriter<Void> {
        private final Session session;
        private final TypeProvider types;
        private final PlanSymbolAllocator planSymbolAllocator;
        private final PlanNodeIdAllocator idAllocator;
        private final WarningCollector warningCollector;
        // Stays null until the first join has been visited.
        private Set<PlanNode> optimizableSources;

        private PreRuleOptimizer(Session session, TypeProvider types, PlanSymbolAllocator symbolAllocator, PlanNodeIdAllocator nodeIdAllocator, WarningCollector collector) {
            this.session = session;
            this.types = types;
            this.planSymbolAllocator = symbolAllocator;
            this.idAllocator = nodeIdAllocator;
            this.warningCollector = collector;
        }

        @Override // io.prestosql.sql.planner.plan.SimplePlanRewriter
        public PlanNode visitPlan(PlanNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            boolean isCandidate = optimizableSources != null && optimizableSources.contains(node);
            if (isCandidate) {
                return rewrite(node);
            }
            return super.visitPlan(node, context);
        }

        public PlanNode visitJoin(JoinNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            if (optimizableSources == null) {
                int maxReorderedJoins = SystemSessionProperties.getMaxReorderedJoins(session);
                optimizableSources = JoinNodeCandidatesExtractor.toMultiJoinNode(node, maxReorderedJoins).getSources();
                PlanNode newLeft = context.rewrite(node.getLeft(), context.get());
                PlanNode newRight = context.rewrite(node.getRight(), context.get());
                PlanNode rebuilt = ChildReplacer.replaceChildren(node, ImmutableList.of(newLeft, newRight));
                return rewrite(rebuilt);
            }
            return visitPlan(node, context);
        }

        /**
         * Runs an {@link IterativeOptimizer} that contains only a new {@link HintedReorderJoinsRule} over {@code node}.
         */
        private PlanNode rewrite(PlanNode node) {
            IterativeOptimizer optimizer = new IterativeOptimizer(
                    HintedReorderJoins.this.stats,
                    HintedReorderJoins.this.statsCalculator,
                    HintedReorderJoins.this.costCalculator,
                    ImmutableSet.of(new HintedReorderJoinsRule(HintedReorderJoins.this.costComparator)));
            return optimizer.optimize(node, session, types, planSymbolAllocator, idAllocator, warningCollector);
        }
    }

    /**
     * Creates the optimizer. Each argument is stored as-is for later use by {@code PreRuleOptimizer}.
     *
     * @param ruleStatsRecorder recorder passed to each {@link IterativeOptimizer}
     * @param statsCalculator stats calculator passed to each {@link IterativeOptimizer}
     * @param costCalculator cost calculator passed to each {@link IterativeOptimizer}
     * @param costComparator comparator passed to each {@link HintedReorderJoinsRule}
     */
    public HintedReorderJoins(RuleStatsRecorder ruleStatsRecorder, StatsCalculator statsCalculator, CostCalculator costCalculator, CostComparator costComparator) {
        this.costComparator = costComparator;
        this.costCalculator = costCalculator;
        this.statsCalculator = statsCalculator;
        this.stats = ruleStatsRecorder;
    }

    /**
     * Rewrites {@code plan} with a fresh {@code PreRuleOptimizer}, which applies hinted join reordering.
     */
    @Override // io.prestosql.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, PlanSymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        PreRuleOptimizer rewriter = new PreRuleOptimizer(session, types, symbolAllocator, idAllocator, warningCollector);
        return SimplePlanRewriter.rewriteWith(rewriter, plan);
    }
}
