package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.Constraint;
import io.prestosql.spi.metadata.TableHandle;
import io.prestosql.spi.plan.AggregationNode;
import io.prestosql.spi.plan.Assignments;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.TableScanNode;
import io.prestosql.spi.plan.ValuesNode;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.spi.statistics.ColumnStatistics;
import io.prestosql.spi.statistics.TableStatistics;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.IndexSourceNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.sql.util.SpecialCommentFormatter;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/TablePushdown.class */
/**
 * Iterative-optimizer rule that moves an INNER {@link JoinNode} beneath the grouping aggregation
 * found on its right input.
 *
 * <p>Using the first equi-join clause of the matched join, the rule walks a root-to-leaf path
 * below each input. It fires when all of the following hold:
 * <ul>
 *   <li>the left input's path ends in a table scan whose join column is unique, either because
 *       its distinct-value estimate equals the table's row-count estimate or because of a hint from
 *       {@link SpecialCommentFormatter#getUniqueColumnTableMap()};</li>
 *   <li>the right input's path contains an {@link AggregationNode} with grouping keys and ends in
 *       a table scan;</li>
 *   <li>the left path has no grouping aggregation, and the two paths do not both contain a join.</li>
 * </ul>
 * The right input's table scan is then replaced by a join between the unique-key side and an
 * identity projection of that scan.
 *
 * <p>Everything discovered during one application is held in a {@code RewriteState} created
 * inside {@link #apply}. The rule object itself holds no mutable fields.
 */
public class TablePushdown implements Rule<JoinNode> {
    private static final Logger LOG = Logger.get(TablePushdown.class);
    private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(joinNode -> joinNode.getType() == JoinNode.Type.INNER);
    // Matches every "_<digits>" run in a symbol name. Removing the runs presumably maps a planner
    // symbol such as "custkey_12" back to its column name "custkey". TODO confirm for all symbols.
    private static final java.util.regex.Pattern SYMBOL_NUMBER_SUFFIX = java.util.regex.Pattern.compile("_\\p{Digit}+");

    private final Metadata metadata;

    /**
     * The input of a node through which a recorded root-to-leaf path continues. Non-join nodes
     * are always recorded as {@code LEFT}.
     */
    public enum DIRECTION {
        LEFT,
        RIGHT
    }

    /** A plan node on a root-to-leaf path, together with the input the path continues into. */
    public static class NodeWithTreeDirection {
        private final DIRECTION direction;
        private final PlanNode node;

        public NodeWithTreeDirection(PlanNode planNode, DIRECTION direction) {
            this.node = planNode;
            this.direction = direction;
        }

        public PlanNode getNode() {
            return this.node;
        }

        public DIRECTION getDirection() {
            return this.direction;
        }
    }

    public TablePushdown(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.shouldEnableTablePushdown(session);
    }

    /**
     * Verifies the preconditions described on the class and, if they hold, returns the rewritten
     * subtree. Otherwise returns an empty result.
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        Map<String, String[]> uniqueColumnsPerTable = SpecialCommentFormatter.getUniqueColumnTableMap();
        RewriteState state = new RewriteState(this.metadata, context, uniqueColumnsPerTable);
        if (!state.verifyJoinConditions(joinNode)) {
            return Rule.Result.empty();
        }
        LOG.info("Table Pushdown preconditions satisfied.");
        PlanNode rewritten = state.planRewriter(joinNode);
        // NOTE(review): this empties whatever map SpecialCommentFormatter returned. If that is its
        // shared instance, the user hints are consumed by the first successful application.
        // Behavior kept from the original; confirm that this is intended.
        uniqueColumnsPerTable.clear();
        return Rule.Result.ofPlanNode(rewritten);
    }

    /**
     * State for a single application of the rule: the two recorded paths below the matched join
     * and the join-column names they are checked against.
     */
    private static final class RewriteState {
        private final Metadata metadata;
        private final Rule.Context context;
        private final Lookup lookup;
        private final Map<String, String[]> uniqueColumnsPerTable;
        // Symbol names of the first equi-join clause (left, right) with "_<digits>" runs removed.
        private final String[] joinCriteriaStrings = new String[2];
        // Path from the matched join's left input down to the table with a unique join column.
        private Stack<NodeWithTreeDirection> outerTablePathStack = new Stack<>();
        // Path from the matched join's right input; must contain a grouping aggregation.
        private Stack<NodeWithTreeDirection> innerTablePathStack = new Stack<>();

        RewriteState(Metadata metadata, Rule.Context context, Map<String, String[]> uniqueColumnsPerTable) {
            this.metadata = metadata;
            this.context = context;
            this.lookup = context.getLookup();
            this.uniqueColumnsPerTable = uniqueColumnsPerTable;
        }

        /**
         * Returns true when the rewrite applies to {@code joinNode}. As a side effect, records the
         * outer and inner paths used later by {@link #planRewriter}.
         */
        boolean verifyJoinConditions(JoinNode joinNode) {
            List<JoinNode.EquiJoinClause> criteria = joinNode.getCriteria();
            if (criteria.isEmpty()) {
                return false;
            }
            JoinNode.EquiJoinClause firstClause = criteria.get(0);
            this.joinCriteriaStrings[0] = SYMBOL_NUMBER_SUFFIX.matcher(firstClause.getLeft().toString()).replaceAll("");
            this.joinCriteriaStrings[1] = SYMBOL_NUMBER_SUFFIX.matcher(firstClause.getRight().toString()).replaceAll("");

            Stack<NodeWithTreeDirection> leftPath = new Stack<>();
            Stack<NodeWithTreeDirection> rightPath = new Stack<>();
            boolean leftReachesUniqueTable = updateStack(resolveNodeFromGroupReference(joinNode, 0), leftPath);
            // The right input is walked only to record the inner path; its own result is not needed.
            updateStack(resolveNodeFromGroupReference(joinNode, 1), rightPath);
            if (!leftReachesUniqueTable) {
                // The rewrite below is built for the unique-key table under the LEFT input: it joins
                // joinNode.getLeft() with the inner table scan and lists the left input's symbols
                // as join outputs. It is not attempted for the mirrored orientation.
                return false;
            }
            this.outerTablePathStack = leftPath;
            this.innerTablePathStack = rightPath;

            // The inner leaf is replaced by a join over a projection of that leaf, so it must be a scan.
            if (this.innerTablePathStack.isEmpty() || !(this.innerTablePathStack.peek().getNode() instanceof TableScanNode)) {
                return false;
            }
            // With a join on the outer path, the rewrite treats the path's child of the deepest such
            // join as that join's RIGHT input and uses the join's first clause. Reject other shapes,
            // which would otherwise duplicate that child and drop its sibling.
            Optional<NodeWithTreeDirection> deepestOuterJoin = deepestJoinOnPath(this.outerTablePathStack);
            if (deepestOuterJoin.isPresent()) {
                JoinNode nestedJoin = (JoinNode) deepestOuterJoin.get().getNode();
                if (deepestOuterJoin.get().getDirection() != DIRECTION.RIGHT || nestedJoin.getCriteria().isEmpty()) {
                    return false;
                }
            }
            return verifyPresenceOfJoinsAndGroupBy();
        }

        /**
         * Depth-first, left-first walk from {@code planNode}, recording on {@code stack} the path
         * to the first table scan with a unique join column.
         *
         * @return true if such a table scan was reached. When false, the stack holds the last path
         *     explored.
         */
        private boolean updateStack(PlanNode planNode, Stack<NodeWithTreeDirection> stack) {
            stack.push(new NodeWithTreeDirection(planNode, DIRECTION.LEFT));
            if (planNode instanceof TableScanNode) {
                return isTableWithUniqueColumns((TableScanNode) planNode);
            }
            if (planNode instanceof ValuesNode || planNode instanceof IndexSourceNode || planNode.getSources().isEmpty()) {
                // Leaves that are not table scans cannot lead to a unique-key table.
                return false;
            }
            if (!(planNode instanceof JoinNode)) {
                return updateStack(resolveNodeFromGroupReference(planNode, 0), stack);
            }
            if (updateStack(resolveNodeFromGroupReference(planNode, 0), stack)) {
                return true;
            }
            // The left subtree failed. Unwind back to THIS join by identity, so a nested join inside
            // the failed subtree is not mistaken for it. Then continue through its right input.
            while (stack.peek().getNode() != planNode) {
                stack.pop();
            }
            stack.pop();
            stack.push(new NodeWithTreeDirection(planNode, DIRECTION.RIGHT));
            return updateStack(resolveNodeFromGroupReference(planNode, 1), stack);
        }

        /** Returns the entry of the join closest to the leaf on {@code path}, if any. */
        private static Optional<NodeWithTreeDirection> deepestJoinOnPath(Stack<NodeWithTreeDirection> path) {
            for (int i = path.size() - 1; i >= 0; i--) {
                if (path.get(i).getNode() instanceof JoinNode) {
                    return Optional.of(path.get(i));
                }
            }
            return Optional.empty();
        }

        /** Checks table statistics first, then the user-supplied unique-column hints. */
        private boolean isTableWithUniqueColumns(TableScanNode tableScanNode) {
            TableHandle table = tableScanNode.getTable();
            TableStatistics tableStatistics = this.metadata.getTableStatistics(this.context.getSession(), table, Constraint.alwaysTrue());
            if (tableStatistics != null && isTableWithUniqueColumnTableStatistics(tableStatistics, table)) {
                return true;
            }
            if (this.uniqueColumnsPerTable.isEmpty()) {
                return false;
            }
            return isTableWithUniqueColumnsUserHint(tableScanNode);
        }

        /**
         * Returns true when the join column's distinct-value estimate equals the table's row-count
         * estimate. The join column is the first of the two join-criteria names that the table has.
         */
        private boolean isTableWithUniqueColumnTableStatistics(TableStatistics tableStatistics, TableHandle tableHandle) {
            Map<String, ColumnHandle> columnHandles = this.metadata.getColumnHandles(this.context.getSession(), tableHandle);
            String joinColumn;
            if (columnHandles.containsKey(this.joinCriteriaStrings[0])) {
                joinColumn = this.joinCriteriaStrings[0];
            } else if (columnHandles.containsKey(this.joinCriteriaStrings[1])) {
                joinColumn = this.joinCriteriaStrings[1];
            } else {
                return false;
            }
            ColumnStatistics columnStatistics = (ColumnStatistics) tableStatistics.getColumnStatistics().get(columnHandles.get(joinColumn));
            if (columnStatistics == null) {
                // No statistics for this column: uniqueness cannot be shown from stats, so fall back
                // to hints instead of failing the whole planning run.
                return false;
            }
            return tableStatistics.getRowCount().getValue() == columnStatistics.getDistinctValuesCount().getValue();
        }

        /** Returns true when this table's hint lists one of the join-criteria names (case-insensitive). */
        private boolean isTableWithUniqueColumnsUserHint(TableScanNode tableScanNode) {
            String[] uniqueColumns = this.uniqueColumnsPerTable.get(extractTableName(tableScanNode));
            if (uniqueColumns == null) {
                return false;
            }
            return Arrays.stream(uniqueColumns).anyMatch(column ->
                    column.equalsIgnoreCase(this.joinCriteriaStrings[0]) || column.equalsIgnoreCase(this.joinCriteriaStrings[1]));
        }

        /** Returns the part of the table's fully qualified name after the last '.'. */
        private String extractTableName(TableScanNode tableScanNode) {
            String fullyQualifiedName = tableScanNode.getTable().getFullyQualifiedName();
            return fullyQualifiedName.substring(fullyQualifiedName.lastIndexOf(".") + 1);
        }

        /**
         * Returns true when all of these hold: the two paths do not both contain a join, the outer
         * path has no grouping aggregation, and the inner path has one.
         */
        private boolean verifyPresenceOfJoinsAndGroupBy() {
            boolean outerHasJoin = verifyIfJoinNodeInPath(this.outerTablePathStack);
            boolean innerHasJoin = verifyIfJoinNodeInPath(this.innerTablePathStack);
            if (outerHasJoin && innerHasJoin) {
                return false;
            }
            if (verifyIfGroupByInPath(this.outerTablePathStack)) {
                return false;
            }
            return verifyIfGroupByInPath(this.innerTablePathStack);
        }

        private static boolean verifyIfJoinNodeInPath(Stack<NodeWithTreeDirection> stack) {
            return stack.stream().anyMatch(entry -> entry.getNode() instanceof JoinNode);
        }

        private static boolean verifyIfGroupByInPath(Stack<NodeWithTreeDirection> stack) {
            return stack.stream()
                    .map(NodeWithTreeDirection::getNode)
                    .anyMatch(node -> node instanceof AggregationNode && !((AggregationNode) node).getGroupingKeys().isEmpty());
        }

        /**
         * Resolves source {@code i} of {@code planNode} through the memo lookup. Index 1 is only
         * valid for joins.
         */
        private PlanNode resolveNodeFromGroupReference(PlanNode planNode, int i) {
            Preconditions.checkState(i == 0 || (planNode instanceof JoinNode && i == 1), "Attempt to access non-existing source of PlanNode");
            Optional<PlanNode> resolved = this.lookup.resolveGroup(planNode.getSources().get(i)).findFirst();
            Preconditions.checkState(resolved.isPresent(), "Attempt to resolve GroupReference when it doesn't exist");
            return resolved.get();
        }

        /** Builds the rewritten plan. Only valid after {@link #verifyJoinConditions} returned true. */
        PlanNode planRewriter(JoinNode joinNode) {
            return updateOuterTableAndInnerTablePath(joinNode, this.outerTablePathStack);
        }

        /**
         * Rewrites for the two supported shapes:
         * <ul>
         *   <li>No join on the outer path: the matched join is replaced by the inner subtree, with
         *       the inner table scan joined against the matched join's left input.</li>
         *   <li>A join on the outer path: the outer-path child of the deepest such join is joined
         *       into the inner subtree instead, and that nested join becomes the new root.</li>
         * </ul>
         */
        private PlanNode updateOuterTableAndInnerTablePath(JoinNode joinNode, Stack<NodeWithTreeDirection> stack) {
            if (!verifyIfJoinNodeInPath(stack)) {
                return updateInnerTable(joinNode, joinNode, this.innerTablePathStack, false);
            }
            // Pop down to the deepest join on the outer path. belowNestedJoin.peek() is then the
            // child the path went through, which is that join's right input (see verifyJoinConditions).
            Stack<NodeWithTreeDirection> belowNestedJoin = new Stack<>();
            while (!(stack.peek().getNode() instanceof JoinNode)) {
                belowNestedJoin.push(stack.pop());
            }
            JoinNode nestedJoin = (JoinNode) stack.peek().getNode();
            PlanNode node = belowNestedJoin.peek().getNode();

            List<Symbol> topOutputs = ImmutableList.copyOf(joinNode.getOutputSymbols());
            List<Symbol> pushedOutputs = ImmutableList.<Symbol>builder()
                    .addAll(node.getOutputSymbols())
                    .addAll(joinNode.getRight().getOutputSymbols())
                    .build();
            List<JoinNode.EquiJoinClause> newInnerJoinCriteria = getNewInnerJoinCriteria(nestedJoin, joinNode);
            List<JoinNode.EquiJoinClause> newOuterJoinCriteria = getNewOuterJoinCriteria(nestedJoin, joinNode);
            Optional<RowExpression> pushedFilter = needNewInnerJoinFilter(joinNode, node) ? joinNode.getFilter() : Optional.empty();

            // Carrier for the pushed-down join. Only its left input, criteria, filter, id and
            // dynamic filters are used when the inner subtree is rebuilt.
            JoinNode pushedJoin = new JoinNode(joinNode.getId(), joinNode.getType(), node, resolveNodeFromGroupReference(joinNode, 1),
                    newInnerJoinCriteria, pushedOutputs, pushedFilter,
                    Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), joinNode.getDynamicFilters());
            JoinNode rebuiltNestedJoin = new JoinNode(nestedJoin.getId(), nestedJoin.getType(), nestedJoin.getLeft(), pushedJoin,
                    newOuterJoinCriteria, topOutputs, joinNode.getFilter(),
                    Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), nestedJoin.getDynamicFilters());
            return updateInnerTable(rebuiltNestedJoin, pushedJoin, this.innerTablePathStack, true);
        }

        /** Clause for the pushed-down join: the nested join's right key and the matched join's right key. */
        private List<JoinNode.EquiJoinClause> getNewInnerJoinCriteria(JoinNode nestedJoin, JoinNode topJoin) {
            return ImmutableList.of(new JoinNode.EquiJoinClause(
                    nestedJoin.getCriteria().get(0).getRight(),
                    topJoin.getCriteria().get(0).getRight()));
        }

        /** Clause for the new root: the nested join's left key and the matched join's right key. */
        private List<JoinNode.EquiJoinClause> getNewOuterJoinCriteria(JoinNode nestedJoin, JoinNode topJoin) {
            return ImmutableList.of(new JoinNode.EquiJoinClause(
                    nestedJoin.getCriteria().get(0).getLeft(),
                    topJoin.getCriteria().get(0).getRight()));
        }

        /**
         * Returns true when the matched join's filter is a comparison where a symbol-reference
         * operand names one of {@code planNode}'s outputs (case-insensitive). Returns false for any
         * other filter shape, since the filter is still applied on the rebuilt outer join.
         */
        private boolean needNewInnerJoinFilter(JoinNode joinNode, PlanNode planNode) {
            if (!joinNode.getFilter().isPresent()) {
                return false;
            }
            RowExpression filter = joinNode.getFilter().get();
            if (!OriginalExpressionUtils.isExpression(filter) || !(OriginalExpressionUtils.castToExpression(filter) instanceof ComparisonExpression)) {
                return false;
            }
            ComparisonExpression comparison = (ComparisonExpression) OriginalExpressionUtils.castToExpression(filter);
            String leftName = comparison.getLeft() instanceof SymbolReference ? ((SymbolReference) comparison.getLeft()).getName() : null;
            String rightName = comparison.getRight() instanceof SymbolReference ? ((SymbolReference) comparison.getRight()).getName() : null;
            return planNode.getOutputSymbols().stream()
                    .map(Symbol::getName)
                    .anyMatch(name -> name.equalsIgnoreCase(leftName) || name.equalsIgnoreCase(rightName));
        }

        /**
         * Rebuilds the inner subtree around {@code pushedJoin}. When {@code wrapInOuterJoin} is set,
         * the result becomes the right input of a copy of {@code outerJoin}.
         */
        private PlanNode updateInnerTable(JoinNode outerJoin, JoinNode pushedJoin, Stack<NodeWithTreeDirection> innerPath, boolean wrapInOuterJoin) {
            PlanNode rebuiltInner = getNewIntermediateTreeAfterInnerTableUpdate(pushedJoin, innerPath);
            if (!wrapInOuterJoin) {
                return rebuiltInner;
            }
            return new JoinNode(outerJoin.getId(), outerJoin.getType(), outerJoin.getLeft(), rebuiltInner,
                    outerJoin.getCriteria(), outerJoin.getOutputSymbols(), outerJoin.getFilter(),
                    Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), outerJoin.getDynamicFilters());
        }

        /**
         * Replaces the table scan at the leaf of {@code innerPath} with a join between
         * {@code joinNode.getLeft()} and an identity projection of that scan. Every ancestor on
         * the path is then rebuilt up to the path's root, which is returned.
         *
         * <p>The new join outputs exactly the projection's symbols, which are the scan's symbols.
         * If the left side's row-count estimate is known and not larger than the projection's, the
         * projection goes on the left and the criteria are flipped.
         */
        private PlanNode getNewIntermediateTreeAfterInnerTableUpdate(JoinNode joinNode, Stack<NodeWithTreeDirection> innerPath) {
            TableScanNode tableScan = (TableScanNode) innerPath.peek().getNode();
            Assignments.Builder identity = Assignments.builder();
            for (Symbol symbol : tableScan.getAssignments().keySet()) {
                identity.put(symbol, OriginalExpressionUtils.castToRowExpression(new SymbolReference(symbol.getName())));
            }
            PlanNode projectNode = new ProjectNode(this.context.getIdAllocator().getNextId(), tableScan, identity.build());
            List<Symbol> outputSymbols = ImmutableList.copyOf(projectNode.getOutputSymbols());

            PlanNodeStatsEstimate otherSideStats = this.context.getStatsProvider().getStats(joinNode.getLeft());
            PlanNodeStatsEstimate projectStats = this.context.getStatsProvider().getStats(projectNode);
            JoinNode replacement;
            if (!otherSideStats.isOutputRowCountUnknown() && otherSideStats.getOutputRowCount() <= projectStats.getOutputRowCount()) {
                replacement = new JoinNode(joinNode.getId(), joinNode.getType(), projectNode, joinNode.getLeft(),
                        joinNode.getCriteria().stream().map(JoinNode.EquiJoinClause::flip).collect(ImmutableList.toImmutableList()),
                        outputSymbols, joinNode.getFilter(),
                        Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), joinNode.getDynamicFilters());
            } else {
                replacement = new JoinNode(joinNode.getId(), joinNode.getType(), joinNode.getLeft(), projectNode,
                        joinNode.getCriteria(), outputSymbols, joinNode.getFilter(),
                        Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), joinNode.getDynamicFilters());
            }

            // Walk from the node just above the scan back up to the path root, swapping in the new
            // child on the input the path went through. Nodes are neither dropped nor duplicated.
            PlanNode rebuilt = replacement;
            for (int i = innerPath.size() - 2; i >= 0; i--) {
                rebuilt = replaceSourceOnPath(innerPath.get(i), rebuilt);
            }
            return rebuilt;
        }

        /** Returns {@code step}'s node with the source the path went through replaced by {@code newChild}. */
        private static PlanNode replaceSourceOnPath(NodeWithTreeDirection step, PlanNode newChild) {
            PlanNode parent = step.getNode();
            int pathIndex = (parent instanceof JoinNode && step.getDirection() == DIRECTION.RIGHT) ? 1 : 0;
            List<PlanNode> sources = parent.getSources();
            ImmutableList.Builder<PlanNode> children = ImmutableList.builder();
            for (int i = 0; i < sources.size(); i++) {
                children.add(i == pathIndex ? newChild : sources.get(i));
            }
            return parent.replaceChildren(children.build());
        }
    }
}
