package io.prestosql.sql.planner.optimizations;

import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.spi.plan.CTEScanNode;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.PlanNodeIdAllocator;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.WindowNode;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.sql.planner.PlanSymbolAllocator;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.SymbolReference;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Plan optimizer that inlines common table expression (CTE) scans whose result is not worth
 * sharing, replacing each such {@link CTEScanNode} with its (rewritten) source subtree.
 *
 * <p>Depending on the flags passed to the constructor, a CTE scan is inlined during the first pass
 * when:
 * <ul>
 *   <li>both inputs of a cross join read the same CTE ({@code pruneCTEWithCrossJoin});</li>
 *   <li>the nearest enclosing filter predicate references a base column of the CTE (one whose name
 *       does not start with the {@code StarTreeAggregationRule} SUM/AVG/COUNT/MAX/MIN prefixes)
 *       and the first input of the CTE's source is not a {@link WindowNode}
 *       ({@code pruneCTEWithFilter});</li>
 *   <li>joins whose left input reads the same CTE see different cumulative dynamic-filter counts
 *       ({@code pruneCTEWithDynFilter}; applied in the second pass).</li>
 * </ul>
 * A second pass runs when needed and additionally inlines every CTE that is referenced at most
 * once after the first pass. The optimizer returns the plan unchanged unless CTE reuse is enabled
 * for the session.
 */
public class PruneCTENodes implements PlanOptimizer {
    private final boolean pruneCTEWithFilter;
    private final boolean pruneCTEWithCrossJoin;
    private final boolean pruneCTEWithDynFilter;

    /**
     * Rewrite context propagated down the plan: the number of dynamic filters accumulated from the
     * enclosing joins, and the predicate of the nearest enclosing filter (may be {@code null}).
     */
    public static class ExpressionDetails {
        final int dynFilCount;
        final RowExpression predicate;

        /**
         * @param dynFilCount cumulative number of dynamic filters from the enclosing joins
         * @param predicate predicate of the nearest enclosing filter, or {@code null} if none
         */
        public ExpressionDetails(int dynFilCount, RowExpression predicate) {
            this.dynFilCount = dynFilCount;
            this.predicate = predicate;
        }

        public int getDynFilCount() {
            return this.dynFilCount;
        }

        public RowExpression getPredicate() {
            return this.predicate;
        }
    }

    /**
     * Stateful rewriter shared by both passes of {@link #optimize}; the state gathered during the
     * first pass decides what the second pass inlines.
     */
    private static class OptimizedPlanRewriter extends SimplePlanRewriter<ExpressionDetails> {
        private final boolean pruneCTEWithFilter;
        private final boolean pruneCTEWithCrossJoin;
        private final boolean pruneCTEWithDynFilter;
        // False during the first pass; set by isSecondTraverseRequired() before the second pass.
        private boolean isNodeAlreadyVisited;
        // CTE reference numbers that are read by both inputs of some cross join.
        private final Set<Integer> cteWithCrossJoin = new HashSet<>();
        // CTE reference number -> number of scans of that CTE kept during the first pass.
        private final Map<Integer, Integer> cteUsageMap = new HashMap<>();
        // CTE reference number -> cumulative dynamic-filter count at the first join reading it.
        private final Map<Integer, Integer> cteJoinDynMap = new HashMap<>();
        // CTE reference numbers whose joins disagree on the cumulative dynamic-filter count.
        private final Set<Integer> cteToPrune = new HashSet<>();

        private OptimizedPlanRewriter(boolean isNodeAlreadyVisited, boolean pruneCTEWithFilter,
                boolean pruneCTEWithCrossJoin, boolean pruneCTEWithDynFilter) {
            this.isNodeAlreadyVisited = isNodeAlreadyVisited;
            this.pruneCTEWithFilter = pruneCTEWithFilter;
            this.pruneCTEWithCrossJoin = pruneCTEWithCrossJoin;
            this.pruneCTEWithDynFilter = pruneCTEWithDynFilter;
        }

        @Override
        public PlanNode visitFilter(FilterNode filterNode, SimplePlanRewriter.RewriteContext<ExpressionDetails> rewriteContext) {
            // Descendants see this filter's predicate; the dynamic-filter count is inherited as is.
            return rewriteContext.defaultRewrite(filterNode,
                    new ExpressionDetails(currentDynFilCount(rewriteContext), filterNode.getPredicate()));
        }

        @Override
        public PlanNode visitJoin(JoinNode joinNode, SimplePlanRewriter.RewriteContext<ExpressionDetails> rewriteContext) {
            if (this.pruneCTEWithCrossJoin && joinNode.isCrossJoin()) {
                Integer leftRef = getChildCTERefNum(joinNode.getLeft());
                Integer rightRef = getChildCTERefNum(joinNode.getRight());
                if (leftRef != null && leftRef.equals(rightRef)) {
                    this.cteWithCrossJoin.add(leftRef);
                }
            }
            int dynFilCount = currentDynFilCount(rewriteContext) + joinNode.getDynamicFilters().size();
            if (this.pruneCTEWithDynFilter) {
                Integer leftRef = getChildCTERefNum(joinNode.getLeft());
                if (leftRef != null && !this.cteToPrune.contains(leftRef)) {
                    // The first join reading this CTE fixes the expected count; any mismatch marks it.
                    Integer expected = this.cteJoinDynMap.putIfAbsent(leftRef, dynFilCount);
                    if (expected != null && expected.intValue() != dynFilCount) {
                        this.cteToPrune.add(leftRef);
                    }
                }
            }
            return rewriteContext.defaultRewrite(joinNode,
                    new ExpressionDetails(dynFilCount, currentPredicate(rewriteContext)));
        }

        /**
         * Returns the CTE reference number of the CTE scan reached from {@code planNode} through
         * project, filter and single-source exchange nodes, or {@code null} if there is none.
         */
        private static Integer getChildCTERefNum(PlanNode planNode) {
            if (planNode instanceof CTEScanNode) {
                return ((CTEScanNode) planNode).getCommonCTERefNum();
            }
            if (planNode instanceof ProjectNode) {
                return getChildCTERefNum(((ProjectNode) planNode).getSource());
            }
            if (planNode instanceof FilterNode) {
                return getChildCTERefNum(((FilterNode) planNode).getSource());
            }
            if (planNode instanceof ExchangeNode && planNode.getSources().size() == 1) {
                return getChildCTERefNum(planNode.getSources().get(0));
            }
            return null;
        }

        /**
         * Returns {@code true} unless the symbol name starts with one of the
         * {@code StarTreeAggregationRule} SUM/AVG/COUNT/MAX/MIN prefixes.
         */
        private static boolean isSymbolBaseColumn(String str) {
            return !(str.startsWith(StarTreeAggregationRule.SUM)
                    || str.startsWith(StarTreeAggregationRule.AVG)
                    || str.startsWith(StarTreeAggregationRule.COUNT)
                    || str.startsWith(StarTreeAggregationRule.MAX)
                    || str.startsWith(StarTreeAggregationRule.MIN));
        }

        /**
         * Returns whether {@code rowExpression} is an original (AST) expression that is a
         * {@link SymbolReference} whose name passes {@link #isSymbolBaseColumn(String)}.
         */
        private static boolean isExprBaseColumn(RowExpression rowExpression) {
            if (!OriginalExpressionUtils.isExpression(rowExpression)) {
                return false;
            }
            Object expression = OriginalExpressionUtils.castToExpression(rowExpression);
            return expression instanceof SymbolReference
                    && isSymbolBaseColumn(((SymbolReference) expression).getName());
        }

        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public PlanNode visitCTEScan(CTEScanNode cTEScanNode, SimplePlanRewriter.RewriteContext<ExpressionDetails> rewriteContext) {
            Integer commonCTERefNum = cTEScanNode.getCommonCTERefNum();
            if (this.pruneCTEWithCrossJoin && this.cteWithCrossJoin.contains(commonCTERefNum)) {
                return inlineCteScan(cTEScanNode, rewriteContext);
            }
            if (!this.isNodeAlreadyVisited) {
                if (this.pruneCTEWithFilter && isPrunableByFilter(cTEScanNode, currentPredicate(rewriteContext))) {
                    return inlineCteScan(cTEScanNode, rewriteContext);
                }
                // Only scans kept in the first pass are counted.
                this.cteUsageMap.merge(commonCTERefNum, 1, Integer::sum);
            }
            else if (this.cteUsageMap.getOrDefault(commonCTERefNum, 0) <= 1 || this.cteToPrune.contains(commonCTERefNum)) {
                // getOrDefault: a CTE not counted in the first pass is treated as unused, not an NPE.
                return inlineCteScan(cTEScanNode, rewriteContext);
            }
            return visitPlan(cTEScanNode, rewriteContext);
        }

        /**
         * Decides whether a CTE scan should be inlined because of the enclosing filter: the
         * predicate must reference a base column of the CTE, and the first input of the CTE's
         * source must not be a {@link WindowNode}.
         */
        private static boolean isPrunableByFilter(CTEScanNode cteScanNode, RowExpression predicate) {
            if (predicate == null) {
                return false;
            }
            List<Symbol> baseColumns = getBaseColumns(cteScanNode);
            if (SymbolsExtractor.extractUnique(predicate).stream().noneMatch(baseColumns::contains)) {
                return false;
            }
            List<PlanNode> sourceInputs = cteScanNode.getSource().getSources();
            return sourceInputs.isEmpty() || !(sourceInputs.get(0) instanceof WindowNode);
        }

        /**
         * Collects the CTE's base columns: the project assignments that are plain base-column symbol
         * references when the CTE's source is a {@link ProjectNode}, otherwise the base-column output
         * symbols of the scan.
         */
        private static List<Symbol> getBaseColumns(CTEScanNode cteScanNode) {
            PlanNode source = cteScanNode.getSource();
            if (source instanceof ProjectNode) {
                return ((ProjectNode) source).getAssignments().entrySet().stream()
                        .filter(entry -> isExprBaseColumn(entry.getValue()))
                        .map(Map.Entry::getKey)
                        .collect(Collectors.toList());
            }
            return cteScanNode.getOutputSymbols().stream()
                    .filter(symbol -> isSymbolBaseColumn(symbol.getName()))
                    .collect(Collectors.toList());
        }

        /**
         * Rewrites the CTE scan's children and returns its rewritten source, so the CTE scan node is
         * removed and its body is used in its place.
         */
        private PlanNode inlineCteScan(CTEScanNode cteScanNode, SimplePlanRewriter.RewriteContext<ExpressionDetails> rewriteContext) {
            return ((CTEScanNode) visitPlan(cteScanNode, rewriteContext)).getSource();
        }

        /** Dynamic-filter count carried by the context, or 0 when there is no context yet. */
        private static int currentDynFilCount(SimplePlanRewriter.RewriteContext<ExpressionDetails> rewriteContext) {
            ExpressionDetails details = rewriteContext.get();
            return details == null ? 0 : details.getDynFilCount();
        }

        /** Enclosing filter predicate carried by the context, or {@code null} when there is none. */
        private static RowExpression currentPredicate(SimplePlanRewriter.RewriteContext<ExpressionDetails> rewriteContext) {
            ExpressionDetails details = rewriteContext.get();
            return details == null ? null : details.getPredicate();
        }

        /**
         * Switches the rewriter into second-pass mode and reports whether that pass is needed: it is
         * when some CTE was kept at most once in the first pass or some CTE was marked for pruning.
         */
        private boolean isSecondTraverseRequired() {
            boolean hasSingleUseCte = this.cteUsageMap.values().stream().anyMatch(count -> count <= 1);
            this.isNodeAlreadyVisited = hasSingleUseCte || !this.cteToPrune.isEmpty();
            return this.isNodeAlreadyVisited;
        }
    }

    /**
     * @param pruneCTEWithFilter inline CTE scans whose enclosing filter references a base column
     * @param pruneCTEWithCrossJoin inline CTEs read by both inputs of a cross join
     * @param pruneCTEWithDynFilter inline CTEs whose joins see differing dynamic-filter counts
     */
    public PruneCTENodes(boolean pruneCTEWithFilter, boolean pruneCTEWithCrossJoin, boolean pruneCTEWithDynFilter) {
        this.pruneCTEWithFilter = pruneCTEWithFilter;
        this.pruneCTEWithCrossJoin = pruneCTEWithCrossJoin;
        this.pruneCTEWithDynFilter = pruneCTEWithDynFilter;
    }

    @Override // io.prestosql.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, PlanSymbolAllocator planSymbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(typeProvider, "types is null");
        Objects.requireNonNull(planSymbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        if (!SystemSessionProperties.isCTEReuseEnabled(session)) {
            return planNode;
        }
        // One rewriter per call: it carries the state collected in the first pass into the second.
        OptimizedPlanRewriter rewriter = new OptimizedPlanRewriter(false, this.pruneCTEWithFilter, this.pruneCTEWithCrossJoin, this.pruneCTEWithDynFilter);
        PlanNode firstPass = SimplePlanRewriter.rewriteWith(rewriter, planNode);
        if (!rewriter.isSecondTraverseRequired()) {
            return firstPass;
        }
        return SimplePlanRewriter.rewriteWith(rewriter, firstPass);
    }
}
