package io.prestosql.sql.planner;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.plan.AggregationNode;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.LimitNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.TableScanNode;
import io.prestosql.spi.plan.TopNNode;
import io.prestosql.spi.plan.UnionNode;
import io.prestosql.spi.plan.ValuesNode;
import io.prestosql.spi.plan.WindowNode;
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.optimizations.JoinNodeUtils;
import io.prestosql.sql.planner.optimizations.SetOperationNodeUtils;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.DistinctLimitNode;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.InternalPlanVisitor;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.SymbolReference;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/* loaded from: input_file:io/prestosql/sql/planner/EffectivePredicateExtractor.class */
public class EffectivePredicateExtractor {
    /**
     * Tests whether an assignment entry's expression is exactly the symbol reference of its own key,
     * i.e. the entry maps a symbol to itself.
     */
    private static final Predicate<Map.Entry<Symbol, ? extends Expression>> SYMBOL_MATCHES_EXPRESSION =
            entry -> entry.getValue().equals(SymbolUtils.toSymbolReference(entry.getKey()));

    /** Converts an assignment entry into the comparison {@code <key symbol reference> = <value expression>}. */
    private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> ENTRY_TO_EQUALITY =
            entry -> new ComparisonExpression(
                    ComparisonExpression.Operator.EQUAL,
                    SymbolUtils.toSymbolReference(entry.getKey()),
                    entry.getValue());

    // Collaborators handed to every Visitor created by extract().
    private final ExpressionDomainTranslator domainTranslator;
    private final Metadata metadata;
    // When true, table scans report the predicate from table properties rather than their enforced constraint.
    private final boolean useTableProperties;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: io.prestosql.sql.planner.EffectivePredicateExtractor$1, reason: invalid class name */
    /* loaded from: input_file:io/prestosql/sql/planner/EffectivePredicateExtractor$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$io$prestosql$spi$plan$JoinNode$Type;

        static {
            try {
                $SwitchMap$io$prestosql$sql$planner$plan$SpatialJoinNode$Type[SpatialJoinNode.Type.INNER.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$io$prestosql$sql$planner$plan$SpatialJoinNode$Type[SpatialJoinNode.Type.LEFT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            $SwitchMap$io$prestosql$spi$plan$JoinNode$Type = new int[JoinNode.Type.values().length];
            try {
                $SwitchMap$io$prestosql$spi$plan$JoinNode$Type[JoinNode.Type.INNER.ordinal()] = 1;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$io$prestosql$spi$plan$JoinNode$Type[JoinNode.Type.LEFT.ordinal()] = 2;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$io$prestosql$spi$plan$JoinNode$Type[JoinNode.Type.RIGHT.ordinal()] = 3;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$io$prestosql$spi$plan$JoinNode$Type[JoinNode.Type.FULL.ordinal()] = 4;
            } catch (NoSuchFieldError e6) {
            }
        }
    }

    /* loaded from: input_file:io/prestosql/sql/planner/EffectivePredicateExtractor$Visitor.class */
    private static class Visitor extends InternalPlanVisitor<Expression, Void> {
        // Translates the tuple domains built by visitTableScan/visitValues into expressions.
        private final ExpressionDomainTranslator domainTranslator;
        private final Metadata metadata;
        private final Session session;
        // Symbol types for the plan being visited.
        private final TypeProvider types;
        private final TypeAnalyzer typeAnalyzer;
        // Selects the predicate source used by visitTableScan.
        private final boolean useTableProperties;

        /**
         * Creates a visitor bound to one session and plan.
         *
         * @param expressionDomainTranslator translator for tuple domains; must not be null
         * @param metadata metadata service; must not be null
         * @param session current session; must not be null
         * @param typeProvider symbol types of the plan; must not be null
         * @param typeAnalyzer expression type analyzer; must not be null
         * @param z when true, table scans take their predicate from table properties
         * @throws NullPointerException if any reference argument is null
         */
        public Visitor(ExpressionDomainTranslator expressionDomainTranslator, Metadata metadata, Session session, TypeProvider typeProvider, TypeAnalyzer typeAnalyzer, boolean z) {
            // Null checks run in declaration order so the first offending argument is the one reported.
            this.domainTranslator = Objects.requireNonNull(expressionDomainTranslator, "domainTranslator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.types = Objects.requireNonNull(typeProvider, "types is null");
            this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
            this.useTableProperties = z;
        }

        /**
         * Default case: returns the literal TRUE, i.e. no predicate is derived for the node.
         *
         * @param planNode the node being visited (unused)
         * @param r4 visitor context (unused)
         * @return {@code BooleanLiteral.TRUE_LITERAL}
         */
        public Expression visitPlan(PlanNode planNode, Void r4) {
            return BooleanLiteral.TRUE_LITERAL;
        }

        /**
         * Derives the predicate for an aggregation.
         *
         * @param aggregationNode the aggregation being visited
         * @param r6 visitor context, forwarded to the source
         * @return TRUE when there are no grouping keys; otherwise the source's predicate pulled
         *         through the grouping keys
         */
        public Expression visitAggregation(AggregationNode aggregationNode, Void r6) {
            if (aggregationNode.getGroupingKeys().isEmpty()) {
                return BooleanLiteral.TRUE_LITERAL;
            }
            Expression sourcePredicate = aggregationNode.getSource().accept(this, r6);
            return pullExpressionThroughSymbols(sourcePredicate, aggregationNode.getGroupingKeys());
        }

        /**
         * Combines the deterministic conjuncts of the filter's own predicate with the predicate
         * derived from its source.
         *
         * @param filterNode the filter being visited
         * @param r7 visitor context, forwarded to the source
         * @return conjunction of the filter's deterministic conjuncts and the source predicate
         */
        public Expression visitFilter(FilterNode filterNode, Void r7) {
            Expression filterPredicate = OriginalExpressionUtils.castToExpression(filterNode.getPredicate());
            Expression deterministicPart = ExpressionUtils.filterDeterministicConjuncts(filterPredicate);
            Expression sourcePredicate = filterNode.getSource().accept(this, r7);
            return ExpressionUtils.combineConjuncts(deterministicPart, sourcePredicate);
        }

        /**
         * Derives the conjuncts common to all exchange sources. For each source, the i-th output
         * symbol is mapped to a reference to that source's i-th input symbol.
         *
         * @param exchangeNode the exchange being visited
         * @param r6 visitor context (not used here)
         * @return conjunction of the predicates shared by every source
         */
        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public Expression visitExchange(ExchangeNode exchangeNode, Void r6) {
            return deriveCommonPredicates(exchangeNode, sourceIndex -> {
                List<Symbol> inputSymbols = exchangeNode.getInputs().get(sourceIndex);
                Map<Symbol, SymbolReference> outputToInput = new HashMap<>();
                for (int column = 0; column < inputSymbols.size(); column++) {
                    Symbol outputSymbol = exchangeNode.getOutputSymbols().get(column);
                    outputToInput.put(outputSymbol, SymbolUtils.toSymbolReference(inputSymbols.get(column)));
                }
                return outputToInput.entrySet();
            });
        }

        /**
         * Derives the predicate for a projection: every assignment that is not a plain
         * self-reference contributes an equality {@code symbol = expression}; these are combined
         * with the source predicate and pulled through the projection's output symbols.
         *
         * @param projectNode the projection being visited
         * @param r6 visitor context, forwarded to the source
         * @return the predicate expressed over the projection's output symbols
         */
        public Expression visitProject(ProjectNode projectNode, Void r6) {
            Map<Symbol, Expression> assignments =
                    Maps.transformValues(projectNode.getAssignments().getMap(), OriginalExpressionUtils::castToExpression);

            List<Expression> projectionEqualities = assignments.entrySet().stream()
                    .filter(EffectivePredicateExtractor.SYMBOL_MATCHES_EXPRESSION.negate())
                    .map(EffectivePredicateExtractor.ENTRY_TO_EQUALITY)
                    .collect(ImmutableList.toImmutableList());

            Expression sourcePredicate = projectNode.getSource().accept(this, r6);

            Expression combined = ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                    .addAll(projectionEqualities)
                    .add(sourcePredicate)
                    .build());
            return pullExpressionThroughSymbols(combined, projectNode.getOutputSymbols());
        }

        /**
         * Returns the source's predicate unchanged.
         *
         * @param topNNode the TopN node being visited
         * @param r6 visitor context, forwarded to the source
         */
        public Expression visitTopN(TopNNode topNNode, Void r6) {
            PlanNode source = topNNode.getSource();
            return source.accept(this, r6);
        }

        /**
         * Returns the source's predicate unchanged.
         *
         * @param limitNode the limit node being visited
         * @param r6 visitor context, forwarded to the source
         */
        public Expression visitLimit(LimitNode limitNode, Void r6) {
            PlanNode source = limitNode.getSource();
            return source.accept(this, r6);
        }

        /**
         * Returns the source's predicate unchanged.
         *
         * @param assignUniqueId the node being visited
         * @param r6 visitor context, forwarded to the source
         */
        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public Expression visitAssignUniqueId(AssignUniqueId assignUniqueId, Void r6) {
            PlanNode source = assignUniqueId.getSource();
            return source.accept(this, r6);
        }

        /**
         * Returns the source's predicate unchanged.
         *
         * @param distinctLimitNode the node being visited
         * @param r6 visitor context, forwarded to the source
         */
        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public Expression visitDistinctLimit(DistinctLimitNode distinctLimitNode, Void r6) {
            PlanNode source = distinctLimitNode.getSource();
            return source.accept(this, r6);
        }

        /**
         * Derives the predicate of a table scan from a column-handle tuple domain, re-keyed by the
         * scan's output symbols.
         *
         * <p>The domain is the scan's enforced constraint, or, when {@code useTableProperties} is
         * set, the predicate reported by the table's properties.
         *
         * @param tableScanNode the table scan being visited
         * @param r7 visitor context (unused)
         * @return the simplified domain translated to an expression over the scan's symbols
         */
        public Expression visitTableScan(TableScanNode tableScanNode, Void r7) {
            // Scan assignments map symbol -> column handle; the inverse view lets domain keys be
            // translated from column handles back to symbols.
            Map<ColumnHandle, Symbol> symbolsByColumn = ImmutableBiMap.copyOf(tableScanNode.getAssignments()).inverse();

            TupleDomain<ColumnHandle> predicate = tableScanNode.getEnforcedConstraint();
            if (this.useTableProperties) {
                predicate = this.metadata.getTableProperties(this.session, tableScanNode.getTable()).getPredicate();
            }

            return this.domainTranslator.toPredicate(predicate.simplify().transform(symbolsByColumn::get));
        }

        /**
         * Returns the source's predicate unchanged.
         *
         * @param sortNode the sort node being visited
         * @param r6 visitor context, forwarded to the source
         */
        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public Expression visitSort(SortNode sortNode, Void r6) {
            PlanNode source = sortNode.getSource();
            return source.accept(this, r6);
        }

        /**
         * Returns the source's predicate unchanged.
         *
         * @param windowNode the window node being visited
         * @param r6 visitor context, forwarded to the source
         */
        public Expression visitWindow(WindowNode windowNode, Void r6) {
            PlanNode source = windowNode.getSource();
            return source.accept(this, r6);
        }

        /**
         * Derives the conjuncts common to all union sources, using the union's output-symbol
         * mapping for each source.
         *
         * @param unionNode the union being visited
         * @param r6 visitor context (not used here)
         * @return conjunction of the predicates shared by every source
         */
        public Expression visitUnion(UnionNode unionNode, Void r6) {
            return deriveCommonPredicates(
                    unionNode,
                    sourceIndex -> SetOperationNodeUtils.outputSymbolMap(unionNode, sourceIndex.intValue()).entries());
        }

        /**
         * Derives the predicate that holds on the output of a join.
         *
         * <ul>
         *   <li>INNER: both side predicates, the equi-join criteria and the join filter (TRUE if
         *       absent) are combined and pulled through the output symbols.</li>
         *   <li>LEFT / RIGHT: the preserved side's predicate is pulled through the output symbols;
         *       the other side's conjuncts and the join criteria go through
         *       {@link #pullNullableConjunctsThroughOuterJoin} with the non-preserved side's output
         *       symbols as the nullable scope.</li>
         *   <li>FULL: conjuncts of both sides and the join criteria all go through
         *       {@link #pullNullableConjunctsThroughOuterJoin}; the criteria use both sides as
         *       nullable scopes.</li>
         * </ul>
         *
         * @param joinNode the join being visited
         * @param r11 visitor context, forwarded to both children
         * @return the effective predicate over the join's output symbols
         * @throws UnsupportedOperationException if the join type is not one of the four above
         */
        public Expression visitJoin(JoinNode joinNode, Void r11) {
            Expression leftPredicate = joinNode.getLeft().accept(this, r11);
            Expression rightPredicate = joinNode.getRight().accept(this, r11);
            List<Expression> joinConjuncts = joinNode.getCriteria().stream()
                    .map(JoinNodeUtils::toExpression)
                    .collect(ImmutableList.toImmutableList());

            // Switch on the enum itself rather than on an ordinal lookup table.
            switch (joinNode.getType()) {
                case INNER: {
                    Expression combined = ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(leftPredicate)
                            .add(rightPredicate)
                            .add(ExpressionUtils.combineConjuncts(joinConjuncts))
                            .add(joinNode.getFilter().map(OriginalExpressionUtils::castToExpression).orElse(BooleanLiteral.TRUE_LITERAL))
                            .build());
                    return pullExpressionThroughSymbols(combined, joinNode.getOutputSymbols());
                }
                case LEFT: {
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(leftPredicate, joinNode.getOutputSymbols()))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(rightPredicate),
                                    joinNode.getOutputSymbols(),
                                    joinNode.getRight().getOutputSymbols()::contains))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    joinConjuncts,
                                    joinNode.getOutputSymbols(),
                                    joinNode.getRight().getOutputSymbols()::contains))
                            .build());
                }
                case RIGHT: {
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(rightPredicate, joinNode.getOutputSymbols()))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(leftPredicate),
                                    joinNode.getOutputSymbols(),
                                    joinNode.getLeft().getOutputSymbols()::contains))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    joinConjuncts,
                                    joinNode.getOutputSymbols(),
                                    joinNode.getLeft().getOutputSymbols()::contains))
                            .build());
                }
                case FULL: {
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(leftPredicate),
                                    joinNode.getOutputSymbols(),
                                    joinNode.getLeft().getOutputSymbols()::contains))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(rightPredicate),
                                    joinNode.getOutputSymbols(),
                                    joinNode.getRight().getOutputSymbols()::contains))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    joinConjuncts,
                                    joinNode.getOutputSymbols(),
                                    joinNode.getLeft().getOutputSymbols()::contains,
                                    joinNode.getRight().getOutputSymbols()::contains))
                            .build());
                }
                default:
                    throw new UnsupportedOperationException("Unknown join type: " + joinNode.getType());
            }
        }

        /**
         * Derives a predicate describing the literal contents of a VALUES node, column by column.
         *
         * <p>For each output column every cell is optimized to a constant. The column's domain is
         * the union of the non-null constants, plus NULL if any cell evaluated to null. A column
         * containing a non-deterministic cell is left out of the result. If any cell does not
         * reduce to a constant (the optimizer returns an {@code Expression}), the whole method
         * returns TRUE.
         *
         * @param valuesNode the VALUES node being visited
         * @param r7 visitor context (unused)
         * @return the simplified per-column domains translated to an expression, or TRUE
         */
        public Expression visitValues(ValuesNode valuesNode, Void r7) {
            if (valuesNode.getOutputSymbols().isEmpty()) {
                return BooleanLiteral.TRUE_LITERAL;
            }

            // Analyze the types of every cell once, up front.
            List<Expression> cellExpressions = valuesNode.getRows().stream()
                    .flatMap(List::stream)
                    .map(OriginalExpressionUtils::castToExpression)
                    .collect(Collectors.toList());
            Map<NodeRef<Expression>, Type> expressionTypes = this.typeAnalyzer.getTypes(this.session, this.types, cellExpressions);

            ImmutableMap.Builder<Symbol, Domain> domains = ImmutableMap.builder();
            for (int column = 0; column < valuesNode.getOutputSymbols().size(); column++) {
                Symbol symbol = valuesNode.getOutputSymbols().get(column);
                Type type = this.types.get(symbol);

                ImmutableList.Builder<Object> nonNullValues = ImmutableList.builder();
                boolean sawNull = false;
                boolean sawNonDeterministic = false;

                for (int row = 0; row < valuesNode.getRows().size(); row++) {
                    Expression cell = OriginalExpressionUtils.castToExpression(valuesNode.getRows().get(row).get(column));
                    if (!ExpressionDeterminismEvaluator.isDeterministic(cell)) {
                        sawNonDeterministic = true;
                        break;
                    }

                    Object evaluated = ExpressionInterpreter.expressionOptimizer(cell, this.metadata, this.session, expressionTypes)
                            .optimize(NoOpSymbolResolver.INSTANCE);
                    if (evaluated instanceof Expression) {
                        // Not reducible to a constant: give up on the whole node.
                        return BooleanLiteral.TRUE_LITERAL;
                    }
                    if (evaluated == null) {
                        sawNull = true;
                    }
                    else {
                        nonNullValues.add(evaluated);
                    }
                }

                if (sawNonDeterministic) {
                    // No domain is recorded for this column.
                    continue;
                }

                List<Object> values = nonNullValues.build();
                Domain domain = Domain.none(type);
                if (!values.isEmpty()) {
                    domain = domain.union(Domain.multipleValues(type, values));
                }
                if (sawNull) {
                    domain = domain.union(Domain.onlyNull(type));
                }
                domains.put(symbol, domain);
            }

            return this.domainTranslator.toPredicate(TupleDomain.withColumnDomains(domains.build()).simplify());
        }

        /**
         * Pulls each conjunct through the given output symbols, replaces conjuncts that end up
         * referencing no symbols with TRUE, and finally maps each result with
         * {@code ExpressionUtils.expressionOrNullSymbols(predicateArr)}.
         *
         * @param list conjuncts to pull up
         * @param collection symbols the resulting conjuncts may reference
         * @param predicateArr symbol scopes handed to {@code expressionOrNullSymbols}
         * @return the transformed conjuncts, in input order
         */
        private static Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> list, Collection<Symbol> collection, Predicate<Symbol>... predicateArr) {
            return list.stream()
                    .map(conjunct -> {
                        Expression pulled = pullExpressionThroughSymbols(conjunct, collection);
                        // A conjunct with no symbols left says nothing about the output rows.
                        return SymbolsExtractor.extractAll(pulled).isEmpty() ? BooleanLiteral.TRUE_LITERAL : pulled;
                    })
                    .map(ExpressionUtils.expressionOrNullSymbols(predicateArr))
                    .collect(ImmutableList.toImmutableList());
        }

        /**
         * Returns the predicate of the semi-join's source unchanged.
         *
         * @param semiJoinNode the semi-join being visited
         * @param r6 visitor context, forwarded to the source
         */
        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public Expression visitSemiJoin(SemiJoinNode semiJoinNode, Void r6) {
            PlanNode source = semiJoinNode.getSource();
            return source.accept(this, r6);
        }

        /**
         * Derives the predicate that holds on the output of a spatial join.
         *
         * <ul>
         *   <li>INNER: both side predicates, each pulled through the output symbols.</li>
         *   <li>LEFT: the left predicate pulled through the output symbols, plus the right side's
         *       conjuncts passed through {@link #pullNullableConjunctsThroughOuterJoin} with the
         *       right side's output symbols as the nullable scope.</li>
         * </ul>
         *
         * @param spatialJoinNode the spatial join being visited
         * @param r11 visitor context, forwarded to both children
         * @return the effective predicate over the join's output symbols
         * @throws IllegalArgumentException for any other join type
         */
        @Override // io.prestosql.sql.planner.plan.InternalPlanVisitor
        public Expression visitSpatialJoin(SpatialJoinNode spatialJoinNode, Void r11) {
            Expression leftPredicate = spatialJoinNode.getLeft().accept(this, r11);
            Expression rightPredicate = spatialJoinNode.getRight().accept(this, r11);

            switch (spatialJoinNode.getType()) {
                case INNER:
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(leftPredicate, spatialJoinNode.getOutputSymbols()))
                            .add(pullExpressionThroughSymbols(rightPredicate, spatialJoinNode.getOutputSymbols()))
                            .build());
                case LEFT:
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(leftPredicate, spatialJoinNode.getOutputSymbols()))
                            .addAll(pullNullableConjunctsThroughOuterJoin(
                                    ExpressionUtils.extractConjuncts(rightPredicate),
                                    spatialJoinNode.getOutputSymbols(),
                                    spatialJoinNode.getRight().getOutputSymbols()::contains))
                            .build());
                default:
                    throw new IllegalArgumentException("Unsupported spatial join type: " + spatialJoinNode.getType());
            }
        }

        /**
         * Computes the conjuncts that hold for every source of a multi-source node.
         *
         * <p>For each source, the symbol mapping (minus self-references) is turned into equalities,
         * combined with the source's own predicate, pulled through the node's output symbols and
         * split into a set of conjuncts. The result is the conjunction of the conjuncts present in
         * all of those sets (compared by {@code equals}).
         *
         * @param planNode the node whose sources are examined; the first {@code iterator.next()}
         *        throws {@link java.util.NoSuchElementException} if it has no sources
         * @param function maps a source index to that source's (output symbol, input reference) entries
         * @return conjunction of the predicates common to all sources
         */
        private Expression deriveCommonPredicates(PlanNode planNode, Function<Integer, Collection<Map.Entry<Symbol, SymbolReference>>> function) {
            List<Set<Expression>> conjunctsPerSource = new ArrayList<>();
            for (int i = 0; i < planNode.getSources().size(); i++) {
                List<Expression> equalities = function.apply(i).stream()
                        .filter(EffectivePredicateExtractor.SYMBOL_MATCHES_EXPRESSION.negate())
                        .map(EffectivePredicateExtractor.ENTRY_TO_EQUALITY)
                        .collect(ImmutableList.toImmutableList());

                Expression sourcePredicate = planNode.getSources().get(i).accept(this, null);

                Expression pulled = pullExpressionThroughSymbols(
                        ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                                .addAll(equalities)
                                .add(sourcePredicate)
                                .build()),
                        planNode.getOutputSymbols());

                conjunctsPerSource.add(ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(pulled)));
            }

            // Fold the per-source sets into their intersection. The accumulator is typed as Set
            // because it starts as a plain set and only later becomes a Sets.intersection view.
            Iterator<Set<Expression>> iterator = conjunctsPerSource.iterator();
            Set<Expression> commonConjuncts = iterator.next();
            while (iterator.hasNext()) {
                commonConjuncts = Sets.intersection(commonConjuncts, iterator.next());
            }
            return ExpressionUtils.combineConjuncts(commonConjuncts);
        }

        /**
         * Applies {@link #pullExpressionThroughSymbols} to every expression in the list.
         *
         * @param list expressions to rewrite
         * @param collection symbols the rewritten expressions may reference
         * @return an immutable list of the rewritten expressions, in input order
         */
        private static List<Expression> pullExpressionsThroughSymbols(List<Expression> list, Collection<Symbol> collection) {
            ImmutableList.Builder<Expression> pulled = ImmutableList.builder();
            for (Expression expression : list) {
                pulled.add(pullExpressionThroughSymbols(expression, collection));
            }
            return pulled.build();
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Rewrites a predicate so that it is expressed in terms of the given symbols.
         *
         * <p>Each deterministic non-inferrable conjunct is rewritten through the expression's
         * equality inference, scoped to {@code collection}; conjuncts for which the rewrite yields
         * {@code null} are dropped, as are non-deterministic conjuncts. The scope equalities
         * generated by the inference for the same symbols are appended.
         *
         * @param expression the predicate to rewrite
         * @param collection symbols the result may reference
         * @return conjunction of the rewritten conjuncts and the scope equalities
         */
        public static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> collection) {
            EqualityInference inference = EqualityInference.createEqualityInference(expression);
            ImmutableList.Builder<Expression> effectiveConjuncts = ImmutableList.builder();

            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) {
                if (!ExpressionDeterminismEvaluator.isDeterministic(conjunct)) {
                    continue;
                }
                Expression rewritten = inference.rewriteExpression(conjunct, Predicates.in(collection));
                if (rewritten != null) {
                    effectiveConjuncts.add(rewritten);
                }
            }

            effectiveConjuncts.addAll(inference.generateEqualitiesPartitionedBy(Predicates.in(collection)).getScopeEqualities());
            return ExpressionUtils.combineConjuncts(effectiveConjuncts.build());
        }
    }

    /**
     * Creates an extractor.
     *
     * @param expressionDomainTranslator translator for tuple domains; must not be null
     * @param metadata metadata service; must not be null
     * @param z when true, table scans take their predicate from table properties instead of the
     *        scan's enforced constraint
     * @throws NullPointerException if {@code expressionDomainTranslator} or {@code metadata} is null
     */
    public EffectivePredicateExtractor(ExpressionDomainTranslator expressionDomainTranslator, Metadata metadata, boolean z) {
        this.domainTranslator = Objects.requireNonNull(expressionDomainTranslator, "domainTranslator is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.useTableProperties = z;
    }

    /**
     * Computes the effective predicate of a plan by visiting it with a fresh {@code Visitor}.
     *
     * @param session current session; must not be null
     * @param planNode root of the plan (sub)tree to analyze
     * @param typeProvider symbol types of the plan; must not be null
     * @param typeAnalyzer expression type analyzer; must not be null
     * @return the expression produced by the visitor for {@code planNode}
     * @throws NullPointerException if {@code session}, {@code typeProvider} or {@code typeAnalyzer} is null
     */
    public Expression extract(Session session, PlanNode planNode, TypeProvider typeProvider, TypeAnalyzer typeAnalyzer) {
        Visitor visitor = new Visitor(this.domainTranslator, this.metadata, session, typeProvider, typeAnalyzer, this.useTableProperties);
        // The visitor's context type is Void, so the only valid context value is an untyped null.
        return planNode.accept(visitor, null);
    }
}
