package io.prestosql.cost;

import io.prestosql.Session;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.plan.GroupIdNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.plan.Patterns;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/* loaded from: input_file:io/prestosql/cost/GroupIdStatsRule.class */
/**
 * Stats rule for {@link GroupIdNode}: derives the node's estimate from the statistics of its source
 * and from the node's grouping sets (see {@link #groupBy}).
 */
public class GroupIdStatsRule extends SimpleStatsRule<GroupIdNode> {
    private static final Pattern<GroupIdNode> PATTERN = Patterns.groupId();

    /** Multiplier applied to the source row count when no grouping symbol has known statistics. */
    private static final double UNKNOWN_STATS_ROW_COUNT_FACTOR = 0.9d;

    public GroupIdStatsRule(StatsNormalizer statsNormalizer) {
        super(statsNormalizer);
    }

    @Override
    public Pattern<GroupIdNode> getPattern() {
        return PATTERN;
    }

    /**
     * Computes the estimate for {@code groupIdNode} from the estimate of its source node.
     *
     * @return always present; the result of {@link #groupBy} applied to the source estimate,
     *         the node's grouping sets and its grouping-column mapping
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(GroupIdNode groupIdNode, StatsProvider statsProvider, Lookup lookup,
            Session session, TypeProvider typeProvider) {
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(groupIdNode.getSource());
        return Optional.of(groupBy(sourceStats, groupIdNode.getGroupingSets(), groupIdNode.getGroupingColumns()));
    }

    /**
     * Estimates output statistics for grouping over several grouping sets.
     *
     * <p>For every non-empty grouping set, each symbol is translated through {@code groupingColumns}
     * and the translated symbol's statistics are read from {@code sourceStats}. A copy of those
     * statistics is recorded in the result with the nulls fraction rewritten: 0 if the source had no
     * nulls, otherwise {@code 1 / (NDV + 1)}.
     *
     * <p>The set's group count is the product of {@code NDV (+1 if nulls are present)} over the symbols
     * whose statistics are known. Symbols with unknown statistics contribute a factor of 1. That product
     * is capped at the source row count. For sets with more than one symbol, a heuristic term
     * {@code sourceRowCount / groupCount} is added. An empty grouping set contributes exactly one row.
     *
     * <p>If at least one grouping symbol had known statistics, the output row count is the sum over
     * all sets plus one. Otherwise it is {@code sourceRowCount * 0.9}.
     *
     * @param sourceStats estimate of the input to the grouping
     * @param groupingSets grouping sets; each inner list holds the symbols of one set
     * @param groupingColumns maps each grouping-set symbol to the symbol whose statistics are looked up
     *        in {@code sourceStats}
     * @return the built estimate holding the rewritten per-symbol statistics and the output row count
     */
    public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, List<List<Symbol>> groupingSets,
            Map<Symbol, Symbol> groupingColumns) {
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
        double sourceRowCount = sourceStats.getOutputRowCount();
        double estimatedRowCount = 0.0d;
        boolean anyStatsKnown = false;

        for (List<Symbol> groupingSet : groupingSets) {
            if (groupingSet.isEmpty()) {
                // An empty grouping set produces a single group. Accumulate instead of overwriting the
                // running total so the estimate does not depend on the order of the grouping sets.
                estimatedRowCount += 1.0d;
                continue;
            }

            double distinctCombinations = 1.0d;
            for (Symbol groupingSymbol : groupingSet) {
                // NOTE(review): a symbol absent from groupingColumns yields a null lookup key here;
                // behavior then depends on getSymbolStatistics — confirm callers always provide a mapping.
                Symbol sourceSymbol = groupingColumns.get(groupingSymbol);
                SymbolStatsEstimate symbolStats = sourceStats.getSymbolStatistics(sourceSymbol);

                // NOTE(review): statistics are recorded under the mapped (source-side) symbol, not the
                // grouping-set symbol — confirm this is what downstream consumers look up.
                result.addSymbolStatistics(sourceSymbol, symbolStats.mapNullsFraction(nullsFraction ->
                        nullsFraction == 0.0d ? 0.0d : 1.0d / (symbolStats.getDistinctValuesCount() + 1.0d)));

                if (!symbolStats.isUnknown()) {
                    anyStatsKnown = true;
                    // A column with nulls contributes one extra distinct value (the null group).
                    int nullGroup = symbolStats.getNullsFraction() == 0.0d ? 0 : 1;
                    distinctCombinations *= symbolStats.getDistinctValuesCount() + nullGroup;
                }
            }

            // There cannot be more groups than input rows.
            double groupCount = Math.min(distinctCombinations, sourceRowCount);
            // NOTE(review): heuristic extra term for multi-column sets; its rationale is not evident
            // here. Skip it when groupCount is not positive, which previously produced NaN (0/0) or
            // +Infinity and poisoned the whole estimate.
            double multiColumnTerm = (groupingSet.size() > 1 && groupCount > 0.0d)
                    ? sourceRowCount / groupCount
                    : 0.0d;
            estimatedRowCount += groupCount + multiColumnTerm;
        }

        if (anyStatsKnown) {
            // NOTE(review): the extra +1 row is preserved as-is; its intent is not visible in this code.
            result.setOutputRowCount(estimatedRowCount + 1.0d);
        } else {
            result.setOutputRowCount(sourceRowCount * UNKNOWN_STATS_ROW_COUNT_FACTOR);
        }
        return result.build();
    }
}
