package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.plan.AggregationNode;
import io.prestosql.spi.plan.Assignments;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.TableScanNode;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.ExpressionSymbolInliner;
import io.prestosql.sql.planner.SymbolUtils;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.PlanNodeSearcher;
import io.prestosql.sql.planner.plan.AssignmentUtils;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Rewrites a {@link FilterNode} on top of a {@link SemiJoinNode} into an inner join when the
 * filter requires the semi-join output to be true.
 *
 * <p>The rule fires only when one of the filter's top-level conjuncts is exactly a reference to
 * the semi-join output symbol. The plan
 *
 * <pre>
 * Filter(semiJoinOutput AND rest)
 *   SemiJoin(source.sourceKey IN filteringSource.filteringKey), output: semiJoinOutput
 * </pre>
 *
 * is replaced by
 *
 * <pre>
 * Project(identity of all join outputs, semiJoinOutput := TRUE)
 *   InnerJoin(sourceKey = filteringKey, filter: rest with semiJoinOutput inlined as TRUE)
 *     source
 *     Aggregation(GROUP BY filteringKey)
 *       filteringSource
 * </pre>
 *
 * <p>Grouping the filtering source by its join key yields each key at most once. Each source row
 * therefore matches at most one row, so the inner join does not duplicate source rows (a semi-join
 * never does). If the remaining predicate simplifies to {@code TRUE}, the join gets no filter.
 *
 * <p>The rule is gated by {@link SystemSessionProperties#isRewriteFilteringSemiJoinToInnerJoin}.
 * It is skipped when the semi-join's source contains a {@link TableScanNode} marked for delete.
 * NOTE(review): presumably DELETE planning depends on the semi-join shape; confirm.
 */
public class TransformFilteringSemiJoinToInnerJoin implements Rule<FilterNode> {
    private static final Capture<SemiJoinNode> SEMI_JOIN = Capture.newCapture();
    private static final Pattern<FilterNode> PATTERN = Patterns.filter()
            .with(Patterns.source().matching(Patterns.semiJoin().capturedAs(SEMI_JOIN)));

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isRewriteFilteringSemiJoinToInnerJoin(session);
    }

    /**
     * Replaces {@code Filter(SemiJoin)} with {@code Project(InnerJoin(source, distinct keys))}.
     *
     * @param filterNode the matched filter whose source is a semi-join
     * @param captures holds the captured {@link SemiJoinNode}
     * @param context rule context supplying the plan lookup and the plan-node id allocator
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the rule does not apply
     */
    @Override
    public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
        SemiJoinNode semiJoinNode = captures.get(SEMI_JOIN);
        if (sourceScansDeleteTarget(semiJoinNode, context)) {
            return Rule.Result.empty();
        }

        Symbol semiJoinOutput = semiJoinNode.getSemiJoinOutput();
        Expression semiJoinOutputReference = SymbolUtils.toSymbolReference(semiJoinOutput);
        Predicate<Expression> isSemiJoinOutput = expression -> expression.equals(semiJoinOutputReference);

        List<Expression> conjuncts = ExpressionUtils.extractConjuncts(
                OriginalExpressionUtils.castToExpression(filterNode.getPredicate()));
        if (conjuncts.stream().noneMatch(isSemiJoinOutput)) {
            // The filter does not require the semi-join output to be TRUE. It may keep rows that
            // an inner join would drop, so the rewrite is not safe.
            return Rule.Result.empty();
        }

        // The remaining conjuncts only see rows whose semi-join output is TRUE. Any other
        // occurrence of the output symbol inside them can therefore be replaced by that literal.
        List<Expression> remainingConjuncts = conjuncts.stream()
                .filter(isSemiJoinOutput.negate())
                .collect(ImmutableList.toImmutableList());
        Function<Symbol, Expression> inlineSemiJoinOutput = symbol -> symbol.equals(semiJoinOutput)
                ? BooleanLiteral.TRUE_LITERAL
                : SymbolUtils.toSymbolReference(symbol);
        Expression simplifiedPredicate = ExpressionSymbolInliner.inlineSymbols(
                inlineSemiJoinOutput, ExpressionUtils.and(remainingConjuncts));
        Optional<Expression> joinFilter = simplifiedPredicate.equals(BooleanLiteral.TRUE_LITERAL)
                ? Optional.empty()
                : Optional.of(simplifiedPredicate);

        // The join reuses the semi-join's id and exposes only the source side's symbols.
        // The remaining optional join properties and the trailing map are left empty.
        JoinNode innerJoin = new JoinNode(
                semiJoinNode.getId(),
                JoinNode.Type.INNER,
                semiJoinNode.getSource(),
                distinctFilteringSource(semiJoinNode, context),
                ImmutableList.of(new JoinNode.EquiJoinClause(
                        semiJoinNode.getSourceJoinSymbol(),
                        semiJoinNode.getFilteringSourceJoinSymbol())),
                semiJoinNode.getSource().getOutputSymbols(),
                joinFilter.map(OriginalExpressionUtils::castToRowExpression),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());

        // Re-expose the semi-join output for parent nodes. On every surviving row it is TRUE.
        Assignments assignments = Assignments.builder()
                .putAll(AssignmentUtils.identityAsSymbolReferences(innerJoin.getOutputSymbols()))
                .put(semiJoinOutput, OriginalExpressionUtils.castToRowExpression(BooleanLiteral.TRUE_LITERAL))
                .build();
        return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), innerJoin, assignments));
    }

    /** Returns whether the plan under the semi-join's source contains a TableScanNode marked for delete. */
    private static boolean sourceScansDeleteTarget(SemiJoinNode semiJoinNode, Rule.Context context) {
        return PlanNodeSearcher.searchFrom(semiJoinNode.getSource(), context.getLookup())
                .where(node -> node instanceof TableScanNode && ((TableScanNode) node).isForDelete())
                .matches();
    }

    /** Builds a single-step aggregation that groups the filtering source by its join key (distinct keys). */
    private static AggregationNode distinctFilteringSource(SemiJoinNode semiJoinNode, Rule.Context context) {
        return new AggregationNode(
                context.getIdAllocator().getNextId(),
                semiJoinNode.getFilteringSource(),
                ImmutableMap.of(),
                AggregationNode.singleGroupingSet(ImmutableList.of(semiJoinNode.getFilteringSourceJoinSymbol())),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }
}
