package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.QualifiedObjectName;
import io.prestosql.spi.plan.AggregationNode;
import io.prestosql.spi.plan.Assignments;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.PlanNodeIdAllocator;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.SetOperationNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.UnionNode;
import io.prestosql.spi.relation.CallExpression;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.planner.PlanSymbolAllocator;
import io.prestosql.sql.planner.SymbolUtils;
import io.prestosql.sql.planner.optimizations.SetOperationNodeUtils;
import io.prestosql.sql.planner.optimizations.StarTreeAggregationRule;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.GenericLiteral;
import io.prestosql.sql.tree.Literal;
import io.prestosql.sql.tree.NullLiteral;
import io.prestosql.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Rewrites a non-UNION {@link SetOperationNode} into a plan built from simpler operators.
 *
 * <p>The plan has three layers:
 * <ul>
 *   <li>a projection per source that appends one boolean marker column per source;</li>
 *   <li>a {@link UnionNode} over the marked sources;</li>
 *   <li>a single-step aggregation, grouped by the set operation's output symbols, that computes
 *       {@code count(marker_i)} for every source.</li>
 * </ul>
 *
 * <p>Callers receive the aggregation together with one {@code count_i >= 1} predicate per source.
 * From these they can assemble the filter that implements the specific set operation.
 */
public class SetOperationNodeTranslator {
    private static final String MARKER = "marker";
    private static final QualifiedObjectName COUNT_AGGREGATION_NAME = QualifiedObjectName.valueOf(CatalogSchemaName.DEFAULT_NAMESPACE, StarTreeAggregationRule.COUNT);
    // Right-hand side of the "count_i >= 1" presence predicates.
    private static final Literal GENERIC_LITERAL = new GenericLiteral("BIGINT", "1");

    private final PlanSymbolAllocator planSymbolAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final Metadata metadata;

    /**
     * Result of {@link SetOperationNodeTranslator#makeSetContainmentPlan}.
     *
     * <p>Holds the rewritten plan and, for every source of the original set operation (in source
     * order), a predicate over the plan's count outputs that holds when a row occurs in that source.
     */
    public static class TranslationResult {
        private final PlanNode planNode;
        private final List<Expression> presentExpressions;

        /**
         * @param planNode the rewritten plan
         * @param list one presence predicate per source, in source order; copied defensively
         * @throws NullPointerException if either argument is null
         */
        public TranslationResult(PlanNode planNode, List<Expression> list) {
            this.planNode = Objects.requireNonNull(planNode, "planNode is null");
            this.presentExpressions = ImmutableList.copyOf(Objects.requireNonNull(list, "presentExpressions is null"));
        }

        /** Returns the rewritten plan (the count aggregation). */
        public PlanNode getPlanNode() {
            return this.planNode;
        }

        /** Returns the immutable list of per-source presence predicates, in source order. */
        public List<Expression> getPresentExpressions() {
            return this.presentExpressions;
        }
    }

    /**
     * @param metadata used to resolve the {@code count} aggregate function
     * @param planSymbolAllocator allocator for the new marker, projection and count symbols
     * @param planNodeIdAllocator allocator for the ids of the generated plan nodes
     * @throws NullPointerException if any argument is null
     */
    public SetOperationNodeTranslator(Metadata metadata, PlanSymbolAllocator planSymbolAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        this.planSymbolAllocator = Objects.requireNonNull(planSymbolAllocator, "SymbolAllocator is null");
        this.idAllocator = Objects.requireNonNull(planNodeIdAllocator, "PlanNodeIdAllocator is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Builds the marker / union / count plan for {@code setOperationNode}.
     *
     * @param setOperationNode the set operation to rewrite; must not be a {@link UnionNode}
     * @return the count aggregation plus one {@code count_i >= 1} predicate per source, in source order
     * @throws IllegalArgumentException if {@code setOperationNode} is a {@link UnionNode}
     */
    public TranslationResult makeSetContainmentPlan(SetOperationNode setOperationNode) {
        Preconditions.checkArgument(!(setOperationNode instanceof UnionNode), "Cannot simplify a UnionNode");

        // One boolean marker per source. Marker i is TRUE only on rows that come from source i.
        List<Symbol> markers = allocateSymbols(setOperationNode.getSources().size(), MARKER, BooleanType.BOOLEAN);
        List<PlanNode> markedSources = appendMarkers(markers, setOperationNode.getSources(), setOperationNode);

        // The union reuses the set operation's output symbols, followed by the marker symbols.
        List<Symbol> outputs = setOperationNode.getOutputSymbols();
        UnionNode union = union(markedSources, ImmutableList.copyOf(Iterables.concat(outputs, markers)));

        // count(marker_i) per output row. count(expr) skips NULL markers, so count_i >= 1 means
        // the row occurs in source i.
        List<Symbol> counts = allocateSymbols(markers.size(), StarTreeAggregationRule.COUNT, BigintType.BIGINT);
        AggregationNode aggregation = computeCounts(union, outputs, markers, counts);

        List<Expression> presentExpressions = counts.stream()
                .map(count -> new ComparisonExpression(
                        ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL,
                        SymbolUtils.toSymbolReference(count),
                        GENERIC_LITERAL))
                .collect(ImmutableList.toImmutableList());
        return new TranslationResult(aggregation, presentExpressions);
    }

    /** Allocates {@code count} fresh symbols of {@code type}, each using {@code nameHint} as its name hint. */
    private List<Symbol> allocateSymbols(int count, String nameHint, Type type) {
        ImmutableList.Builder<Symbol> symbols = ImmutableList.builder();
        for (int i = 0; i < count; i++) {
            symbols.add(this.planSymbolAllocator.newSymbol(nameHint, type));
        }
        return symbols.build();
    }

    /** Wraps each source in a projection that re-projects its columns and appends the marker columns. */
    private List<PlanNode> appendMarkers(List<Symbol> markers, List<PlanNode> sources, SetOperationNode setOperationNode) {
        ImmutableList.Builder<PlanNode> markedSources = ImmutableList.builder();
        for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
            markedSources.add(appendMarkers(
                    this.idAllocator,
                    this.planSymbolAllocator,
                    sources.get(sourceIndex),
                    sourceIndex,
                    markers,
                    SetOperationNodeUtils.sourceSymbolMap(setOperationNode, sourceIndex)));
        }
        return markedSources.build();
    }

    /**
     * Builds the projection for a single source.
     *
     * <p>Each entry of {@code projections} is assigned to a fresh symbol named after, and typed like,
     * its key. Then one boolean column per marker is appended: TRUE for {@code markerIndex}, and a
     * boolean-typed NULL for every other marker.
     */
    private static PlanNode appendMarkers(PlanNodeIdAllocator idAllocator, PlanSymbolAllocator symbolAllocator, PlanNode source, int markerIndex, List<Symbol> markers, Map<Symbol, SymbolReference> projections) {
        Assignments.Builder assignments = Assignments.builder();
        for (Map.Entry<Symbol, SymbolReference> entry : projections.entrySet()) {
            Symbol outputSymbol = entry.getKey();
            Symbol projected = symbolAllocator.newSymbol(outputSymbol.getName(), symbolAllocator.getTypes().get(outputSymbol));
            assignments.put(projected, OriginalExpressionUtils.castToRowExpression(entry.getValue()));
        }
        for (int i = 0; i < markers.size(); i++) {
            // The NULL is cast to boolean so that every source produces the same marker column type.
            Expression markerValue = (i == markerIndex) ? BooleanLiteral.TRUE_LITERAL : new Cast(new NullLiteral(), "boolean");
            assignments.put(symbolAllocator.newSymbol(markers.get(i).getName(), BooleanType.BOOLEAN), OriginalExpressionUtils.castToRowExpression(markerValue));
        }
        return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
    }

    /**
     * Unions {@code sources}, mapping the i-th output symbol of every source to {@code outputs.get(i)}.
     *
     * <p>This relies on each marked source's output order matching {@code outputs}: the set
     * operation's outputs followed by the markers. NOTE(review): that order depends on
     * {@code sourceSymbolMap} iterating in output-symbol order; this is not visible from here.
     *
     * @throws IllegalArgumentException if a source does not produce exactly {@code outputs.size()} symbols
     */
    private UnionNode union(List<PlanNode> sources, List<Symbol> outputs) {
        ImmutableListMultimap.Builder<Symbol, Symbol> outputsToInputs = ImmutableListMultimap.builder();
        for (PlanNode source : sources) {
            List<Symbol> sourceOutputs = source.getOutputSymbols();
            Preconditions.checkArgument(
                    sourceOutputs.size() == outputs.size(),
                    "Union source produces %s symbols but %s union outputs were expected",
                    sourceOutputs.size(),
                    outputs.size());
            for (int i = 0; i < sourceOutputs.size(); i++) {
                outputsToInputs.put(outputs.get(i), sourceOutputs.get(i));
            }
        }
        return new UnionNode(this.idAllocator.getNextId(), sources, outputsToInputs.build(), outputs);
    }

    /**
     * Builds a SINGLE-step aggregation over {@code sourceNode}, grouped by {@code groupingSymbols}.
     * It assigns {@code count(markers[i])} to {@code countSymbols[i]} for every marker.
     */
    private AggregationNode computeCounts(UnionNode sourceNode, List<Symbol> groupingSymbols, List<Symbol> markers, List<Symbol> countSymbols) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (int i = 0; i < markers.size(); i++) {
            Symbol marker = markers.get(i);
            aggregations.put(countSymbols.get(i), new AggregationNode.Aggregation(
                    new CallExpression(
                            COUNT_AGGREGATION_NAME.getObjectName(),
                            this.metadata.getFunctionAndTypeManager().lookupFunction(COUNT_AGGREGATION_NAME.getObjectName(), TypeSignatureProvider.fromTypes(BooleanType.BOOLEAN)),
                            BigintType.BIGINT,
                            ImmutableList.of(OriginalExpressionUtils.castToRowExpression(SymbolUtils.toSymbolReference(marker))),
                            Optional.empty()),
                    ImmutableList.of(OriginalExpressionUtils.castToRowExpression(SymbolUtils.toSymbolReference(marker))),
                    false,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty()));
        }
        return new AggregationNode(
                this.idAllocator.getNextId(),
                sourceNode,
                aggregations.build(),
                AggregationNode.singleGroupingSet(groupingSymbols),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }
}
