package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.cost.CachingStatsProvider;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.cost.StatsCalculator;
import io.prestosql.cost.StatsProvider;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.expressions.LogicalRowExpressions;
import io.prestosql.expressions.RowExpressionRewriter;
import io.prestosql.expressions.RowExpressionTreeRewriter;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.Constraint;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.PlanNodeIdAllocator;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.TableScanNode;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.spi.relation.SpecialForm;
import io.prestosql.spi.relation.VariableReferenceExpression;
import io.prestosql.spi.statistics.Estimate;
import io.prestosql.sql.DynamicFilters;
import io.prestosql.sql.planner.PlanSymbolAllocator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.plan.ChildReplacer;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.InternalPlanVisitor;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.relational.FunctionResolution;
import io.prestosql.sql.relational.RowExpressionDeterminismEvaluator;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.class */
public class RemoveUnsupportedDynamicFilters implements PlanOptimizer {
    // Upper bound on (estimated output rows / table row count) used by Rewriter.hasHighSelectivity
    // when judging the build side of a join.
    private static final double DEFAULT_GENERATE_SELECTIVITY_THRESHOLD = 0.5d;
    // Ratio bound (0.01) corresponding to the check made in Rewriter.highSelectivity for a filter
    // sitting directly on a table scan.
    private static final double DEFAULT_REMOVE_SELECTIVITY_THRESHOLD = 0.01d;
    // Used for table statistics lookups and to build LogicalRowExpressions helpers.
    private final Metadata metadata;
    // Source of plan statistics; wrapped in a CachingStatsProvider on every optimize() call.
    private final StatsCalculator statsCalculator;
    // Per-invocation state: reassigned at the start of optimize() and read by the inner Rewriter.
    // NOTE(review): because this lives on the instance, concurrent optimize() calls on one optimizer
    // object would overwrite each other's value - confirm how instances are shared.
    private StatsProvider statsProvider;
    // Per-invocation state: ids of dynamic filters the first pass decided to drop; reassigned in
    // optimize() and handed to both visitors. Same sharing caveat as statsProvider.
    private Set<String> removedDynamicFilterIds;
    // Helper for splitting/combining AND/OR predicates, built once in the constructor.
    private final LogicalRowExpressions logicalRowExpressions;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters$PlanWithConsumedDynamicFilters.class */
    /**
     * Result of visiting a plan subtree: the (possibly rewritten) node together with the ids of
     * the dynamic filters that were kept somewhere inside that subtree.
     */
    public static class PlanWithConsumedDynamicFilters {
        private final PlanNode node;
        private final Set<String> consumedDynamicFilterIds;

        /**
         * @param rewrittenNode root of the rewritten subtree
         * @param consumedIds ids of dynamic filters consumed in the subtree; an immutable
         *        snapshot is taken, so later changes to the argument are not reflected
         */
        PlanWithConsumedDynamicFilters(PlanNode rewrittenNode, Set<String> consumedIds) {
            node = rewrittenNode;
            consumedDynamicFilterIds = ImmutableSet.copyOf(consumedIds);
        }

        /** @return the root of the rewritten subtree */
        PlanNode getNode() {
            return node;
        }

        /** @return immutable set of dynamic filter ids consumed in the subtree */
        Set<String> getConsumedDynamicFilterIds() {
            return consumedDynamicFilterIds;
        }
    }

    /* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters$RemoveFilterVisitor.class */
    private static class RemoveFilterVisitor extends SimplePlanRewriter<Void> {
        private final Set<String> removedDynamicFilterIds;
        private final LogicalRowExpressions logicalRowExpressions;

        public RemoveFilterVisitor(Set<String> set, Metadata metadata) {
            this.removedDynamicFilterIds = set;
            this.logicalRowExpressions = new LogicalRowExpressions(new RowExpressionDeterminismEvaluator(metadata), new FunctionResolution(metadata.getFunctionAndTypeManager()), metadata.getFunctionAndTypeManager());
        }

        public PlanNode visitFilter(FilterNode filterNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            PlanNode rewrite = rewriteContext.rewrite(filterNode.getSource());
            RowExpression predicate = filterNode.getPredicate();
            return new FilterNode(filterNode.getId(), rewrite, rewrite instanceof TableScanNode ? this.logicalRowExpressions.combineConjuncts((Collection) LogicalRowExpressions.extractConjuncts(predicate).stream().filter(rowExpression -> {
                return ((Boolean) DynamicFilters.getDescriptor(rowExpression).map(descriptor -> {
                    return Boolean.valueOf(!this.removedDynamicFilterIds.contains(descriptor.getId()));
                }).orElse(true)).booleanValue();
            }).collect(ImmutableList.toImmutableList())) : predicate);
        }

        public PlanNode visitJoin(JoinNode joinNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            PlanNode rewrite = rewriteContext.rewrite(joinNode.getLeft());
            PlanNode rewrite2 = rewriteContext.rewrite(joinNode.getRight());
            Map map = (Map) joinNode.getDynamicFilters().entrySet().stream().filter(entry -> {
                return !this.removedDynamicFilterIds.contains(entry.getKey());
            }).collect(Collectors.toMap((v0) -> {
                return v0.getKey();
            }, (v0) -> {
                return v0.getValue();
            }));
            return (rewrite == joinNode.getLeft() && rewrite2 == joinNode.getRight() && map.equals(joinNode.getDynamicFilters())) ? joinNode : new JoinNode(joinNode.getId(), joinNode.getType(), rewrite, rewrite2, joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), map);
        }

        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public PlanNode visitSemiJoin(SemiJoinNode semiJoinNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            PlanNode rewrite = rewriteContext.rewrite(semiJoinNode.getFilteringSource());
            PlanNode rewrite2 = rewriteContext.rewrite(semiJoinNode.getSource());
            Optional<String> filter = semiJoinNode.getDynamicFilterId().filter(str -> {
                return !this.removedDynamicFilterIds.contains(str);
            });
            return (rewrite == semiJoinNode.getFilteringSource() && rewrite2 == semiJoinNode.getSource() && filter.equals(semiJoinNode.getDynamicFilterId())) ? semiJoinNode : new SemiJoinNode(semiJoinNode.getId(), semiJoinNode.getSource(), semiJoinNode.getFilteringSource(), semiJoinNode.getSourceJoinSymbol(), semiJoinNode.getFilteringSourceJoinSymbol(), semiJoinNode.getSemiJoinOutput(), semiJoinNode.getSourceHashSymbol(), semiJoinNode.getFilteringSourceHashSymbol(), semiJoinNode.getDistributionType(), filter);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters$Rewriter.class */
    /**
     * Visitor that rewrites a plan bottom-up and reports, for every subtree, which dynamic filter
     * ids were kept ("consumed") inside it. The visitor context is the set of dynamic filter ids
     * that filters in the visited subtree are allowed to keep; joins and semi-joins extend that
     * set for their probe/source side and afterwards drop their own dynamic filters that nobody
     * consumed.
     */
    public class Rewriter extends InternalPlanVisitor<PlanWithConsumedDynamicFilters, Set<String>> {
        private final Metadata metadata;
        private final Session session;
        private final Set<String> removedDynamicFilterIds;

        /**
         * @param session session whose properties drive the selectivity-based decisions
         * @param metadata used to look up table statistics
         * @param set mutable collector for ids of dynamic filters dropped by {@link #visitFilter}
         */
        public Rewriter(Session session, Metadata metadata, Set<String> set) {
            this.session = session;
            this.metadata = metadata;
            this.removedDynamicFilterIds = set;
        }

        /**
         * Default handling: visits every source with the same allowed ids, replaces the children
         * and unions the ids consumed by them.
         */
        public PlanWithConsumedDynamicFilters visitPlan(PlanNode planNode, Set<String> set) {
            List<PlanWithConsumedDynamicFilters> children = planNode.getSources().stream()
                    .map(source -> source.accept(this, set))
                    .collect(ImmutableList.toImmutableList());
            List<PlanNode> rewrittenSources = children.stream()
                    .map(PlanWithConsumedDynamicFilters::getNode)
                    .collect(Collectors.toList());
            Set<String> consumedIds = children.stream()
                    .flatMap(child -> child.getConsumedDynamicFilterIds().stream())
                    .collect(ImmutableSet.toImmutableSet());
            return new PlanWithConsumedDynamicFilters(ChildReplacer.replaceChildren(planNode, rewrittenSources), consumedIds);
        }

        /**
         * Visits the left side with this join's dynamic filter ids added to the allowed set
         * (unless the optimization is enabled and the right side is judged highly selective),
         * keeps only the dynamic filters the left side actually consumed, and visits the right
         * side with the unmodified allowed set.
         */
        public PlanWithConsumedDynamicFilters visitJoin(JoinNode joinNode, Set<String> set) {
            ImmutableSet.Builder<String> allowedOnLeft = ImmutableSet.builder();
            allowedOnLeft.addAll(set);
            if (!SystemSessionProperties.isOptimizeDynamicFilterGeneration(this.session) || !hasHighSelectivity(joinNode.getRight())) {
                allowedOnLeft.addAll(joinNode.getDynamicFilters().keySet());
            }

            PlanWithConsumedDynamicFilters left = joinNode.getLeft().accept(this, allowedOnLeft.build());
            Set<String> consumedOnLeft = left.getConsumedDynamicFilterIds();
            // Ids of this join's dynamic filters that found a consumer on the left side.
            Set<String> retainedIds = joinNode.getDynamicFilters().keySet().stream()
                    .filter(consumedOnLeft::contains)
                    .collect(ImmutableSet.toImmutableSet());

            PlanWithConsumedDynamicFilters right = joinNode.getRight().accept(this, set);

            // Everything consumed below, minus what this join itself satisfies, is reported upward.
            Set<String> unmatchedIds = new HashSet<>(right.getConsumedDynamicFilterIds());
            unmatchedIds.addAll(consumedOnLeft);
            unmatchedIds.removeAll(retainedIds);

            PlanNode leftNode = left.getNode();
            PlanNode rightNode = right.getNode();
            boolean filtersUnchanged = retainedIds.size() == joinNode.getDynamicFilters().size();
            if (leftNode.equals(joinNode.getLeft()) && rightNode.equals(joinNode.getRight()) && filtersUnchanged) {
                return new PlanWithConsumedDynamicFilters(joinNode, unmatchedIds);
            }
            JoinNode rewrittenJoin = new JoinNode(joinNode.getId(), joinNode.getType(), leftNode, rightNode, joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), retainEntries(joinNode.getDynamicFilters(), retainedIds));
            return new PlanWithConsumedDynamicFilters(rewrittenJoin, unmatchedIds);
        }

        /**
         * Same idea as {@link #visitJoin} for a semi-join: the dynamic filter id is allowed on the
         * source side and kept on the node only if something below consumed it.
         */
        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public PlanWithConsumedDynamicFilters visitSemiJoin(SemiJoinNode semiJoinNode, Set<String> set) {
            if (!semiJoinNode.getDynamicFilterId().isPresent()) {
                return visitPlan(semiJoinNode, set);
            }
            String dynamicFilterId = semiJoinNode.getDynamicFilterId().get();
            Set<String> allowedOnSource = ImmutableSet.<String>builder().add(dynamicFilterId).addAll(set).build();
            PlanWithConsumedDynamicFilters source = semiJoinNode.getSource().accept(this, allowedOnSource);
            PlanWithConsumedDynamicFilters filteringSource = semiJoinNode.getFilteringSource().accept(this, set);

            Set<String> unmatchedIds = new HashSet<>(filteringSource.getConsumedDynamicFilterIds());
            unmatchedIds.addAll(source.getConsumedDynamicFilterIds());
            // Set.remove reports whether the id was present, i.e. whether it was consumed below.
            Optional<String> retainedId = unmatchedIds.remove(dynamicFilterId) ? Optional.of(dynamicFilterId) : Optional.empty();

            PlanNode sourceNode = source.getNode();
            PlanNode filteringSourceNode = filteringSource.getNode();
            if (sourceNode.equals(semiJoinNode.getSource()) && filteringSourceNode.equals(semiJoinNode.getFilteringSource()) && retainedId.equals(semiJoinNode.getDynamicFilterId())) {
                return new PlanWithConsumedDynamicFilters(semiJoinNode, unmatchedIds);
            }
            SemiJoinNode rewrittenSemiJoin = new SemiJoinNode(semiJoinNode.getId(), sourceNode, filteringSourceNode, semiJoinNode.getSourceJoinSymbol(), semiJoinNode.getFilteringSourceJoinSymbol(), semiJoinNode.getSemiJoinOutput(), semiJoinNode.getSourceHashSymbol(), semiJoinNode.getFilteringSourceHashSymbol(), semiJoinNode.getDistributionType(), retainedId);
            return new PlanWithConsumedDynamicFilters(rewrittenSemiJoin, unmatchedIds);
        }

        /**
         * Rewrites the filter predicate. Above anything but a table scan all dynamic filters are
         * removed. Above a table scan, allowed dynamic filters are kept, unless the optimization
         * is enabled and the filter is not judged highly selective by
         * {@code highSelectivity}; in that case all of them are removed and their ids are
         * recorded in {@code removedDynamicFilterIds}. A predicate that collapses to TRUE makes
         * the filter node disappear.
         */
        public PlanWithConsumedDynamicFilters visitFilter(FilterNode filterNode, Set<String> set) {
            PlanWithConsumedDynamicFilters rewrittenSource = filterNode.getSource().accept(this, set);
            PlanNode source = rewrittenSource.getNode();
            RowExpression originalPredicate = filterNode.getPredicate();
            ImmutableSet.Builder<String> consumed = ImmutableSet.builder();
            consumed.addAll(rewrittenSource.getConsumedDynamicFilterIds());

            RowExpression rewrittenPredicate;
            if (!(source instanceof TableScanNode)) {
                rewrittenPredicate = removeAllDynamicFilters(originalPredicate);
            }
            else if (!SystemSessionProperties.isOptimizeDynamicFilterGeneration(this.session) || highSelectivity(filterNode, (TableScanNode) source)) {
                rewrittenPredicate = removeDynamicFilters(originalPredicate, set, consumed);
            }
            else {
                // Nothing is allowed, so every dynamic filter is dropped; remember the ids so the
                // producing joins can be cleaned up afterwards.
                rewrittenPredicate = removeDynamicFilters(originalPredicate, ImmutableSet.of(), consumed);
                DynamicFilters.extractDynamicFilters(originalPredicate).getDynamicConjuncts()
                        .forEach(descriptor -> this.removedDynamicFilterIds.add(descriptor.getId()));
            }

            Set<String> consumedIds = consumed.build();
            if (LogicalRowExpressions.TRUE_CONSTANT.equals(rewrittenPredicate)) {
                return new PlanWithConsumedDynamicFilters(source, consumedIds);
            }
            if (originalPredicate.equals(rewrittenPredicate) && source == filterNode.getSource()) {
                return new PlanWithConsumedDynamicFilters(filterNode, consumedIds);
            }
            return new PlanWithConsumedDynamicFilters(new FilterNode(filterNode.getId(), source, rewrittenPredicate), consumedIds);
        }

        /**
         * Judges the build side of a join. Looks through projections and single-source exchanges
         * down to a table scan or a filter directly on a table scan; any other shape, a predicate
         * that already contains a dynamic filter, or unknown statistics yield {@code false}.
         */
        private boolean hasHighSelectivity(PlanNode planNode) {
            if (planNode instanceof ProjectNode) {
                return hasHighSelectivity(((ProjectNode) planNode).getSource());
            }
            if (planNode instanceof ExchangeNode && planNode.getSources().size() == 1) {
                return hasHighSelectivity(planNode.getSources().get(0));
            }

            TableScanNode scan = null;
            Optional<RowExpression> predicate = Optional.empty();
            if (planNode instanceof TableScanNode) {
                scan = (TableScanNode) planNode;
                predicate = scan.getPredicate();
            }
            else if (planNode instanceof FilterNode && ((FilterNode) planNode).getSource() instanceof TableScanNode) {
                scan = (TableScanNode) ((FilterNode) planNode).getSource();
                predicate = Optional.of(((FilterNode) planNode).getPredicate());
            }
            if (scan == null) {
                return false;
            }
            if (predicate.isPresent() && LogicalRowExpressions.extractConjuncts(predicate.get()).stream().anyMatch(DynamicFilters::isDynamicFilter)) {
                return false;
            }
            if (planNode instanceof TableScanNode && scan.getEnforcedConstraint().isAll()) {
                return true;
            }

            Estimate tableRowCount = this.metadata.getTableStatistics(this.session, scan.getTable(), Constraint.alwaysTrue()).getRowCount();
            PlanNodeStatsEstimate stats = RemoveUnsupportedDynamicFilters.this.statsProvider.getStats(planNode);
            if (stats.isOutputRowCountUnknown() || tableRowCount.isUnknown()) {
                return false;
            }
            double outputRows = stats.getOutputRowCount();
            return outputRows > SystemSessionProperties.getDynamicFilteringMaxSize(this.session)
                    || outputRows / tableRowCount.getValue() > RemoveUnsupportedDynamicFilters.DEFAULT_GENERATE_SELECTIVITY_THRESHOLD;
        }

        /**
         * Judges a filter sitting on a table scan. Unknown statistics count as "high"; otherwise
         * the estimated output must exceed the dynamic-filtering max size or the
         * output/table row ratio must exceed {@code DEFAULT_REMOVE_SELECTIVITY_THRESHOLD}.
         *
         * @param filterNode the filter whose output estimate is used
         * @param scan the table scan below the filter, used for the table row count
         */
        private boolean highSelectivity(FilterNode filterNode, TableScanNode scan) {
            Estimate tableRowCount = this.metadata.getTableStatistics(this.session, scan.getTable(), Constraint.alwaysTrue()).getRowCount();
            PlanNodeStatsEstimate stats = RemoveUnsupportedDynamicFilters.this.statsProvider.getStats(filterNode);
            if (stats.isOutputRowCountUnknown() || tableRowCount.isUnknown()) {
                return true;
            }
            double outputRows = stats.getOutputRowCount();
            return outputRows > SystemSessionProperties.getDynamicFilteringMaxSize(this.session)
                    || outputRows / tableRowCount.getValue() > RemoveUnsupportedDynamicFilters.DEFAULT_REMOVE_SELECTIVITY_THRESHOLD;
        }

        /**
         * Keeps non-dynamic conjuncts and those dynamic filters whose input is a plain variable
         * reference and whose id is in {@code set}; every kept dynamic filter id is added to
         * {@code builder}.
         */
        private RowExpression removeDynamicFilters(RowExpression rowExpression, Set<String> set, ImmutableSet.Builder<String> builder) {
            List<RowExpression> keptConjuncts = LogicalRowExpressions.extractConjuncts(rowExpression).stream()
                    .map(this::removeNestedDynamicFilters)
                    .filter(conjunct -> DynamicFilters.getDescriptor(conjunct)
                            .map(descriptor -> {
                                if (descriptor.getInput() instanceof VariableReferenceExpression && set.contains(descriptor.getId())) {
                                    builder.add(descriptor.getId());
                                    return true;
                                }
                                return false;
                            })
                            .orElse(true))
                    .collect(ImmutableList.toImmutableList());
            return RemoveUnsupportedDynamicFilters.this.logicalRowExpressions.combineConjuncts(keptConjuncts);
        }

        /** Removes nested dynamic filters and then every top-level dynamic-filter conjunct. */
        private RowExpression removeAllDynamicFilters(RowExpression rowExpression) {
            RowExpression withoutNested = removeNestedDynamicFilters(rowExpression);
            DynamicFilters.ExtractResult extracted = DynamicFilters.extractDynamicFilters(withoutNested);
            if (extracted.getDynamicConjuncts().isEmpty()) {
                return withoutNested;
            }
            return RemoveUnsupportedDynamicFilters.this.logicalRowExpressions.combineConjuncts(extracted.getStaticConjuncts());
        }

        /**
         * Replaces dynamic filters that appear as direct arguments of AND/OR special forms with
         * the TRUE constant, recursing through nested AND/OR forms.
         */
        private RowExpression removeNestedDynamicFilters(RowExpression rowExpression) {
            return RowExpressionTreeRewriter.rewriteWith(new RowExpressionRewriter<Void>() {
                @Override
                public RowExpression rewriteSpecialForm(SpecialForm specialForm, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
                    SpecialForm.Form form = specialForm.getForm();
                    // Only logical connectives are rewritten. (The guard must use &&: with ||
                    // it is true for every form and the rewrite below never runs.)
                    if (form != SpecialForm.Form.AND && form != SpecialForm.Form.OR) {
                        return specialForm;
                    }
                    SpecialForm rewritten = (SpecialForm) treeRewriter.defaultRewrite(specialForm, context);
                    boolean modified = specialForm != rewritten;
                    ImmutableList.Builder<RowExpression> arguments = ImmutableList.builder();
                    for (RowExpression argument : rewritten.getArguments()) {
                        if (DynamicFilters.isDynamicFilter(argument)) {
                            arguments.add(LogicalRowExpressions.TRUE_CONSTANT);
                            modified = true;
                        }
                        else {
                            arguments.add(argument);
                        }
                    }
                    if (!modified) {
                        return specialForm;
                    }
                    return RemoveUnsupportedDynamicFilters.this.logicalRowExpressions.combinePredicates(form, arguments.build());
                }
            }, rowExpression);
        }

        /** Copies the entries of {@code dynamicFilters} whose key is contained in {@code idsToKeep}. */
        private <V> Map<String, V> retainEntries(Map<String, V> dynamicFilters, Set<String> idsToKeep) {
            return dynamicFilters.entrySet().stream()
                    .filter(entry -> idsToKeep.contains(entry.getKey()))
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        }
    }

    public RemoveUnsupportedDynamicFilters(Metadata metadata, StatsCalculator statsCalculator) {
        this.metadata = (Metadata) Objects.requireNonNull(metadata, "metadata is null");
        this.statsCalculator = (StatsCalculator) Objects.requireNonNull(statsCalculator, "statsCalculator is null");
        this.logicalRowExpressions = new LogicalRowExpressions(new RowExpressionDeterminismEvaluator(metadata), new FunctionResolution(metadata.getFunctionAndTypeManager()), metadata.getFunctionAndTypeManager());
    }

    /**
     * Runs the two passes: the {@link Rewriter} prunes dynamic filters and collects the ids it
     * dropped, then the {@code RemoveFilterVisitor} removes those ids from the remaining plan.
     * Only {@code planNode}, {@code session} and {@code planSymbolAllocator} are used.
     *
     * @return the rewritten plan
     */
    @Override // io.prestosql.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, PlanSymbolAllocator planSymbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector) {
        // Fresh per-call state. NOTE(review): stored on the instance, so concurrent calls on the
        // same optimizer object would interfere - confirm how instances are shared.
        this.removedDynamicFilterIds = new HashSet<>();
        this.statsProvider = new CachingStatsProvider(this.statsCalculator, session, planSymbolAllocator.getTypes());

        // The cleanup visitor shares the id set that the first pass fills in.
        RemoveFilterVisitor cleanupVisitor = new RemoveFilterVisitor(this.removedDynamicFilterIds, this.metadata);
        Rewriter pruningVisitor = new Rewriter(session, this.metadata, this.removedDynamicFilterIds);
        PlanWithConsumedDynamicFilters pruned = planNode.accept(pruningVisitor, ImmutableSet.of());
        return SimplePlanRewriter.rewriteWith(cleanupVisitor, pruned.getNode(), null);
    }
}
