package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.function.StandardFunctionResolution;
import io.prestosql.spi.plan.AggregationNode;
import io.prestosql.spi.plan.Assignments;
import io.prestosql.spi.plan.CTEScanNode;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.TableScanNode;
import io.prestosql.spi.relation.CallExpression;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.SymbolUtils;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.StarTreeAggregationRule;
import io.prestosql.sql.planner.plan.ApplyNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.relational.FunctionResolution;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.GenericLiteral;
import io.prestosql.sql.tree.InPredicate;
import io.prestosql.sql.tree.LogicalBinaryExpression;
import io.prestosql.sql.tree.SymbolReference;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate.class */
/**
 * Rewrites the subquery of an uncorrelated {@code IN} predicate when that subquery is a self-join.
 *
 * <p>The recognised subquery shape is
 * {@code Project -> [CTEScan -> Project...] -> Filter -> Join(INNER, Scan(T), Scan(T))}.
 * The filter must be {@code <x> = <y> AND <u> <> <v>} (in either order), where the equality involves
 * the subquery's single output column. Such a subquery is replaced by
 * {@code Filter(count(DISTINCT args) > 1) -> Aggregation(GROUP BY output) -> Scan(T)}, reading the table
 * once instead of joining it with itself. If the self-join is not found directly under the filter, the
 * rule recurses into {@code ProjectNode} children of the join.
 */
public class TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate implements Rule<ApplyNode> {
    private static final Capture<ProjectNode> PROJECT_NODE = Capture.newCapture();
    private static final Pattern<ProjectNode> PROJECT_NODE_PATTERN = Patterns.project().capturedAs(PROJECT_NODE);
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode()
            .with(Pattern.empty(Patterns.Apply.correlation()))
            .with(Patterns.Apply.subQuery().matching(PROJECT_NODE_PATTERN));

    private final StandardFunctionResolution functionResolution;

    public TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate(Metadata metadata) {
        Objects.requireNonNull(metadata, "metadata is null");
        this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager());
    }

    @Override // io.prestosql.sql.planner.iterative.Rule
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    @Override // io.prestosql.sql.planner.iterative.Rule
    public boolean isEnabled(Session session) {
        return session.getSystemProperty(SystemSessionProperties.TRANSFORM_SELF_JOIN_TO_GROUPBY, Boolean.class);
    }

    /**
     * Applies the rewrite when the apply node carries exactly one subquery assignment and that
     * assignment is an {@code IN} predicate; otherwise returns {@link Rule.Result#empty()}.
     */
    @Override // io.prestosql.sql.planner.iterative.Rule
    public Rule.Result apply(ApplyNode applyNode, Captures captures, Rule.Context context) {
        if (applyNode.getSubqueryAssignments().size() != 1) {
            return Rule.Result.empty();
        }
        RowExpression subqueryExpression = (RowExpression) Iterables.getOnlyElement(applyNode.getSubqueryAssignments().getExpressions());
        if (!(OriginalExpressionUtils.castToExpression(subqueryExpression) instanceof InPredicate)) {
            return Rule.Result.empty();
        }
        Optional<ProjectNode> newSubquery = transformProjectNode(context, captures.get(PROJECT_NODE));
        if (!newSubquery.isPresent()) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(new ApplyNode(
                applyNode.getId(),
                applyNode.getInput(),
                newSubquery.get(),
                applyNode.getSubqueryAssignments(),
                applyNode.getCorrelation(),
                applyNode.getOriginSubquery()));
    }

    /** Resolves {@code planNode} and skips over any chain of {@link ProjectNode}s, returning the first non-project node. */
    private PlanNode getChildFilterNode(Rule.Context context, PlanNode planNode) {
        PlanNode node = context.getLookup().resolve(planNode);
        while (node instanceof ProjectNode) {
            node = context.getLookup().resolve(((ProjectNode) node).getSource());
        }
        return node;
    }

    /**
     * Attempts the rewrite rooted at {@code projectNode}.
     *
     * @param context rule context supplying lookup, id and symbol allocators
     * @param projectNode single-output projection at the top of the subquery
     * @return the rewritten projection, or empty when the shape is not recognised
     */
    private Optional<ProjectNode> transformProjectNode(Rule.Context context, ProjectNode projectNode) {
        // The rewrite groups by the subquery's only output column.
        if (projectNode.getOutputSymbols().size() != 1) {
            return Optional.empty();
        }
        Lookup lookup = context.getLookup();
        PlanNode projectSource = lookup.resolve(projectNode.getSource());
        PlanNode filterCandidate = projectSource instanceof CTEScanNode
                ? getChildFilterNode(context, ((CTEScanNode) projectSource).getSource())
                : projectSource;
        if (!(filterCandidate instanceof FilterNode)) {
            return Optional.empty();
        }
        FilterNode filterNode = (FilterNode) filterCandidate;
        PlanNode filterSource = lookup.resolve(filterNode.getSource());
        if (!(filterSource instanceof JoinNode)) {
            return Optional.empty();
        }
        JoinNode joinNode = (JoinNode) filterSource;
        Expression predicate = OriginalExpressionUtils.castToExpression(filterNode.getPredicate());
        if (!isSelfJoin(projectNode, predicate, joinNode, lookup)) {
            return transformJoinChildren(context, projectNode, filterNode, joinNode);
        }
        return rewriteSelfJoin(context, projectNode, joinNode, predicate);
    }

    /**
     * Recursively rewrites {@link ProjectNode} children of a join that is not itself a matching
     * self-join, then rebuilds the plan above it. Returns empty when neither child changed.
     */
    private Optional<ProjectNode> transformJoinChildren(Rule.Context context, ProjectNode projectNode, FilterNode filterNode, JoinNode joinNode) {
        Lookup lookup = context.getLookup();
        PlanNode newLeft = joinNode.getLeft();
        PlanNode newRight = joinNode.getRight();
        boolean changed = false;

        PlanNode resolvedLeft = lookup.resolve(joinNode.getLeft());
        if (resolvedLeft instanceof ProjectNode) {
            Optional<ProjectNode> rewritten = transformProjectNode(context, (ProjectNode) resolvedLeft);
            if (rewritten.isPresent()) {
                newLeft = rewritten.get();
                changed = true;
            }
        }
        PlanNode resolvedRight = lookup.resolve(joinNode.getRight());
        if (resolvedRight instanceof ProjectNode) {
            Optional<ProjectNode> rewritten = transformProjectNode(context, (ProjectNode) resolvedRight);
            if (rewritten.isPresent()) {
                newRight = rewritten.get();
                changed = true;
            }
        }
        if (!changed) {
            return Optional.empty();
        }

        JoinNode newJoin = new JoinNode(joinNode.getId(), joinNode.getType(), newLeft, newRight, joinNode.getCriteria(),
                joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(),
                joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
        FilterNode newFilter = new FilterNode(filterNode.getId(), newJoin, filterNode.getPredicate());
        // Rebuild whatever sat between the project and the filter (a CTE scan and its projections) instead of
        // dropping it: the project's assignments reference symbols produced by that chain, not by the filter.
        return Optional.of(new ProjectNode(projectNode.getId(), replaceFilter(context, projectNode.getSource(), newFilter), projectNode.getAssignments()));
    }

    /**
     * Returns a copy of the plan rooted at {@code node} in which the filter found under an optional
     * {@link CTEScanNode} and any {@link ProjectNode}s is replaced by {@code newFilter}.
     */
    private PlanNode replaceFilter(Rule.Context context, PlanNode node, FilterNode newFilter) {
        PlanNode resolved = context.getLookup().resolve(node);
        if (resolved instanceof CTEScanNode) {
            CTEScanNode cteScan = (CTEScanNode) resolved;
            return cteScan.replaceChildren(ImmutableList.of(replaceFilter(context, cteScan.getSource(), newFilter)));
        }
        if (resolved instanceof ProjectNode) {
            ProjectNode project = (ProjectNode) resolved;
            return new ProjectNode(project.getId(), replaceFilter(context, project.getSource(), newFilter), project.getAssignments());
        }
        return newFilter;
    }

    /**
     * Replaces a verified self-join with {@code count(DISTINCT args) > 1} grouped by the subquery's output
     * column over a single scan of the table. Callers must have checked {@link #isSelfJoin} first.
     */
    private Optional<ProjectNode> rewriteSelfJoin(Rule.Context context, ProjectNode projectNode, JoinNode joinNode, Expression predicate) {
        Lookup lookup = context.getLookup();
        // isSelfJoin guarantees both join inputs are table scans of the same table.
        TableScanNode leftScan = (TableScanNode) lookup.resolve(joinNode.getLeft());
        TableScanNode rightScan = (TableScanNode) lookup.resolve(joinNode.getRight());
        Symbol outputSymbol = Iterables.getOnlyElement(projectNode.getOutputSymbols());

        List<SymbolReference> predicateSymbols = new ArrayList<>();
        getAllSymbols(predicate, predicateSymbols);

        PlanNode projectSource = lookup.resolve(projectNode.getSource());
        boolean viaCte = projectSource instanceof CTEScanNode;
        ProjectNode cteProject = null;
        RowExpression outputExpression = null;
        Symbol groupingSymbol;
        if (viaCte) {
            // isSelfJoin guarantees the CTE source is a ProjectNode that maps outputSymbol to a plain symbol reference.
            cteProject = (ProjectNode) lookup.resolve(((CTEScanNode) projectSource).getSource());
            outputExpression = cteProject.getAssignments().get(outputSymbol);
            String sourceName = ((SymbolReference) OriginalExpressionUtils.castToExpression(outputExpression)).getName();
            List<Symbol> matches = new ArrayList<>();
            for (Symbol symbol : rightScan.getOutputSymbols()) {
                if (symbol.getName().equals(sourceName)) {
                    matches.add(symbol);
                }
            }
            for (Symbol symbol : leftScan.getOutputSymbols()) {
                if (symbol.getName().equals(sourceName)) {
                    matches.add(symbol);
                }
            }
            // Exactly one scan column must back the projected value; otherwise decline instead of throwing.
            if (matches.size() != 1) {
                return Optional.empty();
            }
            groupingSymbol = matches.get(0);
        }
        else {
            groupingSymbol = outputSymbol;
        }

        TableScanNode aggregationSource = leftScan.getOutputSymbols().contains(groupingSymbol) ? leftScan : rightScan;
        // Every predicate column from the retained scan, except the grouping column, becomes a DISTINCT argument.
        List<RowExpression> distinctArguments = predicateSymbols.stream()
                .filter(reference -> aggregationSource.getOutputSymbols().contains(SymbolUtils.from(reference)))
                .filter(reference -> !groupingSymbol.equals(SymbolUtils.from(reference)))
                .map(OriginalExpressionUtils::castToRowExpression)
                .collect(Collectors.toList());
        if (distinctArguments.isEmpty()) {
            // count(DISTINCT) with no arguments would not express the inequality; leave the plan untouched.
            return Optional.empty();
        }

        AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
                new CallExpression(StarTreeAggregationRule.COUNT, this.functionResolution.countFunction(), BigintType.BIGINT, distinctArguments),
                distinctArguments,
                true,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        Symbol countSymbol = context.getSymbolAllocator().newSymbol(aggregation.getFunctionCall().getDisplayName(), (Type) BigintType.BIGINT);
        AggregationNode aggregationNode = new AggregationNode(
                context.getIdAllocator().getNextId(),
                aggregationSource,
                ImmutableMap.of(countSymbol, aggregation),
                new AggregationNode.GroupingSetDescriptor(ImmutableList.of(groupingSymbol), 1, ImmutableSet.of()),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
        FilterNode havingFilter = new FilterNode(
                context.getIdAllocator().getNextId(),
                aggregationNode,
                OriginalExpressionUtils.castToRowExpression(new ComparisonExpression(
                        ComparisonExpression.Operator.GREATER_THAN,
                        SymbolUtils.toSymbolReference(countSymbol),
                        new GenericLiteral("BIGINT", "1"))));

        if (!viaCte) {
            return Optional.of(new ProjectNode(projectNode.getId(), havingFilter, projectNode.getAssignments()));
        }

        // Re-create the CTE body as outputSymbol := expr over groupingSymbol := expr over the new filter.
        // The inner projection gets a freshly allocated id so no two nodes in the plan share one.
        ProjectNode groupingProject = new ProjectNode(
                context.getIdAllocator().getNextId(),
                havingFilter,
                new Assignments(ImmutableMap.of(groupingSymbol, outputExpression)));
        ProjectNode outputProject = new ProjectNode(
                cteProject.getId(),
                groupingProject,
                new Assignments(ImmutableMap.of(outputSymbol, outputExpression)));
        CTEScanNode newCteScan = (CTEScanNode) ((CTEScanNode) projectSource).replaceChildren(ImmutableList.of(outputProject));
        newCteScan.setOutputSymbols(projectNode.getOutputSymbols());
        return Optional.of(new ProjectNode(projectNode.getId(), newCteScan, projectNode.getAssignments()));
    }

    /**
     * Checks that {@code joinNode} is an inner join of two scans of the same table, filtered by
     * {@code <cmp> AND <cmp>}, where one comparison is an equality involving the subquery's output column
     * and the other is an inequality.
     *
     * <p>When the project sits on a CTE scan, the output column is first traced through the CTE's
     * projection to the same-named symbol produced beneath it. Returns {@code false}, rather than throwing,
     * when that projection is missing or is not a plain symbol reference.
     */
    private static boolean isSelfJoin(ProjectNode projectNode, Expression expression, JoinNode joinNode, Lookup lookup) {
        if (joinNode.getType() != JoinNode.Type.INNER) {
            return false;
        }
        PlanNode left = lookup.resolve(joinNode.getLeft());
        PlanNode right = lookup.resolve(joinNode.getRight());
        if (!(left instanceof TableScanNode) || !(right instanceof TableScanNode)) {
            return false;
        }
        if (!((TableScanNode) left).getTable().getFullyQualifiedName().equals(((TableScanNode) right).getTable().getFullyQualifiedName())) {
            return false;
        }
        if (!(expression instanceof LogicalBinaryExpression)) {
            return false;
        }
        LogicalBinaryExpression conjunction = (LogicalBinaryExpression) expression;
        // The count(DISTINCT) > 1 rewrite is only equivalent for a conjunction; an OR would change results.
        if (conjunction.getOperator() != LogicalBinaryExpression.Operator.AND
                || !(conjunction.getLeft() instanceof ComparisonExpression)
                || !(conjunction.getRight() instanceof ComparisonExpression)) {
            return false;
        }
        if (projectNode.getOutputSymbols().size() != 1) {
            return false;
        }

        Symbol outputSymbol = Iterables.getOnlyElement(projectNode.getOutputSymbols());
        SymbolReference inputReference = SymbolUtils.toSymbolReference(outputSymbol);
        PlanNode projectSource = lookup.resolve(projectNode.getSource());
        if (projectSource instanceof CTEScanNode) {
            PlanNode cteSource = lookup.resolve(((CTEScanNode) projectSource).getSource());
            if (!(cteSource instanceof ProjectNode)) {
                return false;
            }
            ProjectNode cteProject = (ProjectNode) cteSource;
            RowExpression assigned = cteProject.getAssignments().get(outputSymbol);
            if (assigned == null) {
                return false;
            }
            Expression assignedExpression = OriginalExpressionUtils.castToExpression(assigned);
            if (!(assignedExpression instanceof SymbolReference)) {
                return false;
            }
            String sourceName = ((SymbolReference) assignedExpression).getName();
            for (Symbol symbol : lookup.resolve(cteProject.getSource()).getOutputSymbols()) {
                if (symbol.getName().equals(sourceName)) {
                    inputReference = SymbolUtils.toSymbolReference(symbol);
                }
            }
        }

        ComparisonExpression first = (ComparisonExpression) conjunction.getLeft();
        ComparisonExpression second = (ComparisonExpression) conjunction.getRight();
        if (first.getChildren().contains(inputReference)
                && first.getOperator() == ComparisonExpression.Operator.EQUAL
                && second.getOperator() == ComparisonExpression.Operator.NOT_EQUAL) {
            return true;
        }
        return second.getChildren().contains(inputReference)
                && second.getOperator() == ComparisonExpression.Operator.EQUAL
                && first.getOperator() == ComparisonExpression.Operator.NOT_EQUAL;
    }

    /**
     * Appends to {@code list}, in left-to-right order, every {@link SymbolReference} reachable through
     * nested logical-binary and comparison expressions. Other expression kinds are not descended into.
     */
    private static void getAllSymbols(Expression expression, List<SymbolReference> list) {
        if (expression instanceof LogicalBinaryExpression) {
            LogicalBinaryExpression logicalBinaryExpression = (LogicalBinaryExpression) expression;
            getAllSymbols(logicalBinaryExpression.getLeft(), list);
            getAllSymbols(logicalBinaryExpression.getRight(), list);
        }
        else if (expression instanceof ComparisonExpression) {
            ComparisonExpression comparisonExpression = (ComparisonExpression) expression;
            getAllSymbols(comparisonExpression.getLeft(), list);
            getAllSymbols(comparisonExpression.getRight(), list);
        }
        else if (expression instanceof SymbolReference) {
            list.add((SymbolReference) expression);
        }
    }
}
