package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.plan.Assignments;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.UnionNode;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.spi.relation.VariableReferenceExpression;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.ExpressionSymbolInliner;
import io.prestosql.sql.planner.RowExpressionVariableInliner;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.SetOperationNodeUtils;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/PushProjectionThroughUnion.class */
/**
 * Optimizer rule that pushes a {@link ProjectNode} below a {@link UnionNode}.
 *
 * <p>Rewrites {@code Project(Union(s_0, ..., s_n))} into {@code Union(Project_0(s_0), ..., Project_n(s_n))}.
 * Each {@code Project_i} is the original projection with the union's output symbols replaced by the
 * matching input symbols of source {@code i}. The rewritten union keeps the project's output layout
 * and reuses the project's plan node id.
 *
 * <p>The rule only matches when at least one assignment of the projection is something other than a
 * bare symbol/variable reference.
 */
public class PushProjectionThroughUnion implements Rule<ProjectNode> {
    private static final Capture<UnionNode> CHILD = Capture.newCapture();
    private static final Pattern<ProjectNode> PATTERN = Patterns.project()
            .matching(PushProjectionThroughUnion::nonTrivialProjection)
            .with(Patterns.source().matching(Patterns.union().capturedAs(CHILD)));

    @Override
    public Pattern<ProjectNode> getPattern() {
        return PATTERN;
    }

    /**
     * Replaces the matched project-over-union with a union of per-source projections.
     *
     * @param projectNode the matched projection sitting directly on top of a union
     * @param captures holds the captured child {@link UnionNode}
     * @param context supplies the symbol allocator (types, new symbols) and plan node id allocator
     * @return a result whose plan node is the new {@link UnionNode}
     */
    @Override
    public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
        UnionNode unionNode = captures.get(CHILD);
        // The new union exposes exactly the project's output symbols, in the same order.
        List<Symbol> outputSymbols = projectNode.getOutputSymbols();
        // project output symbol -> ordered list of replacement symbols, one per union source
        ImmutableListMultimap.Builder<Symbol, Symbol> outputToSourceSymbols = ImmutableListMultimap.builder();
        ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();

        for (int i = 0; i < unionNode.getSources().size(); i++) {
            // union output symbol -> input symbol of source i (for AST-based expressions)
            Map<Symbol, SymbolReference> sourceSymbolMap = SetOperationNodeUtils.sourceSymbolMap(unionNode, i);
            // Same substitution for RowExpression-based assignments. It depends only on i, so it is
            // built at most once per source, and only if some assignment actually needs it.
            Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping = null;
            Assignments.Builder assignments = Assignments.builder();
            // original project symbol -> symbol produced by the new per-source projection
            Map<Symbol, Symbol> projectSymbolMapping = new HashMap<>();

            for (Map.Entry<Symbol, RowExpression> entry : projectNode.getAssignments().entrySet()) {
                RowExpression translated;
                Symbol newSymbol;
                if (OriginalExpressionUtils.isExpression(entry.getValue())) {
                    Type type = context.getSymbolAllocator().getTypes().get(entry.getKey());
                    Expression inlined = ExpressionSymbolInliner.inlineSymbols(
                            sourceSymbolMap, OriginalExpressionUtils.castToExpression(entry.getValue()));
                    translated = OriginalExpressionUtils.castToRowExpression(inlined);
                    newSymbol = context.getSymbolAllocator().newSymbol(inlined, type);
                } else {
                    if (variableMapping == null) {
                        variableMapping = sourceVariableMapping(unionNode, i, context.getSymbolAllocator().getSymbols());
                    }
                    translated = RowExpressionVariableInliner.inlineVariables(variableMapping, entry.getValue());
                    newSymbol = context.getSymbolAllocator().newSymbol(translated);
                }
                assignments.put(newSymbol, translated);
                projectSymbolMapping.put(entry.getKey(), newSymbol);
            }

            newSources.add(new ProjectNode(context.getIdAllocator().getNextId(), unionNode.getSources().get(i), assignments.build()));
            for (Symbol outputSymbol : outputSymbols) {
                outputToSourceSymbols.put(outputSymbol, projectSymbolMapping.get(outputSymbol));
            }
        }

        // Build the multimap once and reuse it for both the mapping and the output layout.
        ImmutableListMultimap<Symbol, Symbol> symbolMapping = outputToSourceSymbols.build();
        return Rule.Result.ofPlanNode(new UnionNode(
                projectNode.getId(),
                newSources.build(),
                symbolMapping,
                ImmutableList.copyOf(symbolMapping.keySet())));
    }

    /**
     * Builds the variable substitution from each union output symbol to the input symbol that source
     * {@code sourceIndex} feeds into it.
     *
     * @param unionNode the union whose symbol mapping is read
     * @param sourceIndex index of the union source to translate into
     * @param types symbol types used to construct the variable references
     * @return a fresh mutable map from output variable to source-input variable
     */
    private static Map<VariableReferenceExpression, VariableReferenceExpression> sourceVariableMapping(
            UnionNode unionNode, int sourceIndex, Map<Symbol, Type> types) {
        Map<VariableReferenceExpression, VariableReferenceExpression> mapping = new HashMap<>();
        for (Symbol outputSymbol : unionNode.getSymbolMapping().keySet()) {
            Symbol inputSymbol = unionNode.getSymbolMapping().get(outputSymbol).get(sourceIndex);
            mapping.put(
                    new VariableReferenceExpression(outputSymbol.getName(), types.get(outputSymbol)),
                    new VariableReferenceExpression(inputSymbol.getName(), types.get(inputSymbol)));
        }
        return mapping;
    }

    /**
     * Returns {@code true} when at least one assignment of {@code projectNode} is not a bare
     * symbol/variable reference.
     */
    private static boolean nonTrivialProjection(ProjectNode projectNode) {
        return !projectNode.getAssignments().getExpressions().stream()
                .allMatch(PushProjectionThroughUnion::isPlainReference);
    }

    /**
     * Checks whether {@code expression} is a bare reference. That means a {@link SymbolReference} for
     * AST-backed expressions, or a {@link VariableReferenceExpression} otherwise.
     */
    private static boolean isPlainReference(RowExpression expression) {
        if (OriginalExpressionUtils.isExpression(expression)) {
            return OriginalExpressionUtils.castToExpression(expression) instanceof SymbolReference;
        }
        return expression instanceof VariableReferenceExpression;
    }
}
