package io.prestosql.cost;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.google.common.collect.UnmodifiableIterator;
import io.prestosql.Session;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.expressions.LogicalRowExpressions;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.SymbolUtils;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.util.MoreMath;

import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Stats rule that estimates the output statistics of a {@link JoinNode}.
 *
 * <p>Every join type starts from the cross-join estimate of its two inputs. Inner joins are then
 * narrowed by the equi-join criteria and the optional join filter. Outer joins additionally add an
 * estimated "join complement" (the rows of the outer side(s) that find no match) on top of the
 * inner-join estimate.
 */
public class JoinStatsRule extends SimpleStatsRule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();
    private static final double DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT = 0.5d;
    // Heuristic row-count multiplier applied for each clause whose selectivity is not estimated
    // precisely: auxiliary equi-join clauses, remaining clauses when estimating a join complement,
    // and join filters the filter calculator could not estimate.
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9d;

    private final FilterStatsCalculator filterStatsCalculator;
    private final StatsNormalizer normalizer;
    private final double unmatchedJoinComplementNdvsCoefficient;

    public JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer) {
        this(filterStatsCalculator, normalizer, DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT);
    }

    /**
     * Creates the rule with an explicit join-complement coefficient.
     *
     * @param filterStatsCalculator estimator used for equi-join predicates and join filters
     * @param normalizer normalizer applied to intermediate estimates
     * @param unmatchedJoinComplementNdvsCoefficient multiplier applied to the other side's distinct
     *     value count to obtain the number of distinct values assumed to find a match when
     *     estimating the join complement
     */
    @VisibleForTesting
    JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer, double unmatchedJoinComplementNdvsCoefficient) {
        super(normalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(normalizer, "normalizer is null");
        this.unmatchedJoinComplementNdvsCoefficient = unmatchedJoinComplementNdvsCoefficient;
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates the output of {@code node} according to its join type.
     *
     * @throws IllegalStateException if the join type is not INNER, LEFT, RIGHT or FULL
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(JoinNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) {
        PlanNodeStatsEstimate leftStats = statsProvider.getStats(node.getLeft());
        PlanNodeStatsEstimate rightStats = statsProvider.getStats(node.getRight());
        PlanNodeStatsEstimate crossJoinStats = crossJoinStats(node, leftStats, rightStats, types);

        switch (node.getType()) {
            case INNER:
                return Optional.of(computeInnerJoinStats(node, crossJoinStats, session, types));
            case LEFT:
                return Optional.of(computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case RIGHT:
                return Optional.of(computeRightJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case FULL:
                return Optional.of(computeFullJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            default:
                throw new IllegalStateException("Unknown join type: " + node.getType());
        }
    }

    /** Full join = left join stats plus the unmatched rows of the right side. */
    private PlanNodeStatsEstimate computeFullJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate rightJoinComplementStats = calculateJoinComplementStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats, types);
        return addJoinComplementStats(
                rightStats,
                computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types),
                rightJoinComplementStats);
    }

    /** Left join = inner join stats plus the unmatched rows of the left side. */
    private PlanNodeStatsEstimate computeLeftJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate leftJoinComplementStats = calculateJoinComplementStats(node.getFilter(), node.getCriteria(), leftStats, rightStats, types);
        return addJoinComplementStats(leftStats, innerJoinStats, leftJoinComplementStats);
    }

    /** Right join = inner join stats plus the unmatched rows of the right side (criteria flipped). */
    private PlanNodeStatsEstimate computeRightJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate rightJoinComplementStats = calculateJoinComplementStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats, types);
        return addJoinComplementStats(rightStats, innerJoinStats, rightJoinComplementStats);
    }

    /**
     * Estimates an inner join by applying the equi-join criteria and then the optional join filter
     * to the cross-join estimate.
     *
     * <p>Returns unknown stats if the equi-join estimate has an unknown row count. If the join filter
     * cannot be estimated, the equi-join estimate is scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}.
     */
    private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        List<JoinNode.EquiJoinClause> criteria = node.getCriteria();
        Optional<RowExpression> filter = node.getFilter();

        if (criteria.isEmpty()) {
            if (!filter.isPresent()) {
                return crossJoinStats;
            }
            // NOTE(review): the filter is estimated over the full cross product here, which may be
            // far from the real output size.
            return applyJoinFilter(node, crossJoinStats, filter.get(), session, types);
        }

        PlanNodeStatsEstimate equiJoinEstimate = filterByEquiJoinClauses(crossJoinStats, criteria, session, types);
        if (equiJoinEstimate.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        if (!filter.isPresent()) {
            return equiJoinEstimate;
        }

        PlanNodeStatsEstimate filteredEstimate = applyJoinFilter(node, equiJoinEstimate, filter.get(), session, types);
        if (filteredEstimate.isOutputRowCountUnknown()) {
            // The filter could not be estimated: fall back to a fixed heuristic selectivity.
            return normalizer.normalize(equiJoinEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT), types);
        }
        return filteredEstimate;
    }

    /**
     * Applies a join filter to {@code stats}. The filter may be either a wrapped AST expression or a
     * {@link RowExpression}, which the filter calculator handles through different overloads.
     */
    private PlanNodeStatsEstimate applyJoinFilter(JoinNode node, PlanNodeStatsEstimate stats, RowExpression filter, Session session, TypeProvider types) {
        if (OriginalExpressionUtils.isExpression(filter)) {
            return filterStatsCalculator.filterStats(stats, OriginalExpressionUtils.castToExpression(filter), session, types);
        }
        return filterStatsCalculator.filterStats(stats, filter, session, types, outputLayout(node));
    }

    /** Maps each output position (channel) of {@code node} to the symbol at that position. */
    private static Map<Integer, Symbol> outputLayout(JoinNode node) {
        Map<Integer, Symbol> layout = new HashMap<>();
        int channel = 0;
        for (Symbol symbol : node.getOutputSymbols()) {
            layout.put(channel++, symbol);
        }
        return layout;
    }

    /**
     * Estimates the effect of all equi-join clauses together.
     *
     * <p>Equi-join clauses are often correlated, so multiplying their individual selectivities would
     * underestimate the result. Instead, each clause in turn is tried as the "driving" clause, which
     * is estimated precisely, while the remaining (auxiliary) clauses only apply a heuristic. The
     * candidate with the lowest known row count wins.
     *
     * @throws IllegalArgumentException if {@code clauses} is empty
     */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, Collection<JoinNode.EquiJoinClause> clauses, Session session, TypeProvider types) {
        Preconditions.checkArgument(!clauses.isEmpty(), "clauses is empty");
        PlanNodeStatsEstimate result = PlanNodeStatsEstimate.unknown();
        Deque<JoinNode.EquiJoinClause> remainingClauses = new ArrayDeque<>(clauses);
        JoinNode.EquiJoinClause drivingClause = remainingClauses.poll();
        for (int i = 0; i < clauses.size(); i++) {
            PlanNodeStatsEstimate candidate = filterByEquiJoinClauses(stats, drivingClause, remainingClauses, session, types);
            if (result.isOutputRowCountUnknown()
                    || (!candidate.isOutputRowCountUnknown() && candidate.getOutputRowCount() < result.getOutputRowCount())) {
                result = candidate;
            }
            // Rotate: the current driving clause becomes auxiliary and the next one drives.
            remainingClauses.add(drivingClause);
            drivingClause = remainingClauses.poll();
        }
        return result;
    }

    /** Estimates {@code drivingClause} with the filter calculator, then applies each auxiliary clause heuristically. */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, JoinNode.EquiJoinClause drivingClause, Collection<JoinNode.EquiJoinClause> auxiliaryClauses, Session session, TypeProvider types) {
        ComparisonExpression drivingPredicate = new ComparisonExpression(
                ComparisonExpression.Operator.EQUAL,
                SymbolUtils.toSymbolReference(drivingClause.getLeft()),
                SymbolUtils.toSymbolReference(drivingClause.getRight()));
        PlanNodeStatsEstimate result = filterStatsCalculator.filterStats(stats, drivingPredicate, session, types);
        for (JoinNode.EquiJoinClause clause : auxiliaryClauses) {
            result = filterByAuxiliaryClause(result, clause, types);
        }
        return result;
    }

    /**
     * Applies an auxiliary equi-join clause.
     *
     * <p>The clause clears the nulls fraction of both columns, narrows both to the intersection of
     * their ranges, and caps both NDVs at the smaller in-range NDV. The row count is only scaled by
     * {@link #UNKNOWN_FILTER_COEFFICIENT}, because selectivity is mostly accounted for by the
     * driving clause.
     */
    private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stats, JoinNode.EquiJoinClause clause, TypeProvider types) {
        SymbolStatsEstimate leftStats = stats.getSymbolStatistics(clause.getLeft());
        SymbolStatsEstimate rightStats = stats.getSymbolStatistics(clause.getRight());
        StatisticRange leftRange = StatisticRange.from(leftStats);
        StatisticRange rightRange = StatisticRange.from(rightStats);

        StatisticRange intersect = leftRange.intersect(rightRange);
        double leftNdvInRange = firstNonNaN(leftRange.overlapPercentWith(intersect), 1.0d) * leftRange.getDistinctValuesCount();
        double rightNdvInRange = firstNonNaN(rightRange.overlapPercentWith(intersect), 1.0d) * rightRange.getDistinctValuesCount();
        double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange);

        SymbolStatsEstimate newLeftStats = SymbolStatsEstimate.buildFrom(leftStats)
                .setNullsFraction(0.0d)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        SymbolStatsEstimate newRightStats = SymbolStatsEstimate.buildFrom(rightStats)
                .setNullsFraction(0.0d)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();

        PlanNodeStatsEstimate result = PlanNodeStatsEstimate.buildFrom(stats)
                .setOutputRowCount(stats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT)
                .addSymbolStatistics(clause.getLeft(), newLeftStats)
                .addSymbolStatistics(clause.getRight(), newRightStats)
                .build();
        return normalizer.normalize(result, types);
    }

    /**
     * Returns the first argument that is not NaN.
     *
     * @throws IllegalArgumentException if every argument is NaN (or none is given)
     */
    private static double firstNonNaN(double... values) {
        for (double value : values) {
            if (!Double.isNaN(value)) {
                return value;
            }
        }
        throw new IllegalArgumentException("All values are NaN");
    }

    /**
     * Estimates the rows of the left input that find no match on the right input (the "join
     * complement").
     *
     * <ul>
     *   <li>If the right input has zero rows, every left row is unmatched and {@code leftStats} is
     *       returned.</li>
     *   <li>With no equi-join criteria, the result is unknown when a filter is present, and zero rows
     *       otherwise.</li>
     *   <li>Otherwise each clause is tried as the driving clause. The estimate with the largest known
     *       complement (that is, the fewest matched rows) is kept, or unknown if none is known.</li>
     * </ul>
     */
    @VisibleForTesting
    PlanNodeStatsEstimate calculateJoinComplementStats(Optional<RowExpression> filter, List<JoinNode.EquiJoinClause> criteria, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, TypeProvider types) {
        if (rightStats.getOutputRowCount() == 0.0d) {
            return leftStats;
        }

        if (criteria.isEmpty()) {
            if (filter.isPresent()) {
                // NOTE(review): non-equi join conditions are not accounted for.
                return PlanNodeStatsEstimate.unknown();
            }
            return normalizer.normalize(leftStats.mapOutputRowCount(rowCount -> 0.0d), types);
        }

        int numberOfFilterClauses = filter
                .map(expression -> OriginalExpressionUtils.isExpression(expression)
                        ? ExpressionUtils.extractConjuncts(OriginalExpressionUtils.castToExpression(expression)).size()
                        : LogicalRowExpressions.extractConjuncts(expression).size())
                .orElse(0);
        int numberOfRemainingClauses = criteria.size() - 1 + numberOfFilterClauses;

        return criteria.stream()
                .map(drivingClause -> calculateJoinComplementStats(leftStats, rightStats, drivingClause, numberOfRemainingClauses))
                .filter(estimate -> !estimate.isOutputRowCountUnknown())
                .max(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
                .map(estimate -> normalizer.normalize(estimate, types))
                .orElse(PlanNodeStatsEstimate.unknown());
    }

    /**
     * Estimates the unmatched left rows using a single driving clause.
     *
     * <p>If the left NDV exceeds the (coefficient-scaled) right NDV, the excess left values plus the
     * left nulls are assumed unmatched. If it does not, only the left nulls are unmatched. If either
     * NDV is NaN, the result is unknown. The row count is then divided by
     * {@code UNKNOWN_FILTER_COEFFICIENT ^ numberOfRemainingClauses} to account for the remaining
     * clauses, and capped at the left input row count.
     */
    private PlanNodeStatsEstimate calculateJoinComplementStats(PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, JoinNode.EquiJoinClause drivingClause, int numberOfRemainingClauses) {
        SymbolStatsEstimate leftColumnStats = leftStats.getSymbolStatistics(drivingClause.getLeft());
        SymbolStatsEstimate rightColumnStats = rightStats.getSymbolStatistics(drivingClause.getRight());

        double leftNdv = leftColumnStats.getDistinctValuesCount();
        double matchingRightNdv = rightColumnStats.getDistinctValuesCount() * unmatchedJoinComplementNdvsCoefficient;

        PlanNodeStatsEstimate result;
        if (leftNdv > matchingRightNdv) {
            // The "excess" left distinct values, and all left nulls, are assumed to be unmatched.
            double nonMatchingLeftValuesFraction = leftColumnStats.getValuesFraction() * (leftNdv - matchingRightNdv) / leftNdv;
            double scaleFactor = nonMatchingLeftValuesFraction + leftColumnStats.getNullsFraction();
            double newLeftNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor;
            result = leftStats
                    .mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> SymbolStatsEstimate.buildFrom(columnStats)
                            .setLowValue(leftColumnStats.getLowValue())
                            .setHighValue(leftColumnStats.getHighValue())
                            .setNullsFraction(newLeftNullsFraction)
                            .setDistinctValuesCount(leftNdv - matchingRightNdv)
                            .build())
                    .mapOutputRowCount(rowCount -> rowCount * scaleFactor);
        }
        else if (leftNdv <= matchingRightNdv) {
            // Every non-null left row is assumed to match, so only null left rows are unmatched.
            result = leftStats
                    .mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> SymbolStatsEstimate.buildFrom(columnStats)
                            .setLowValue(Double.NaN)
                            .setHighValue(Double.NaN)
                            .setNullsFraction(1.0d)
                            .setDistinctValuesCount(0.0d)
                            .build())
                    .mapOutputRowCount(rowCount -> rowCount * leftColumnStats.getNullsFraction());
        }
        else {
            // Both comparisons are false only when leftNdv or matchingRightNdv is NaN.
            return PlanNodeStatsEstimate.unknown();
        }

        // Account for the remaining clauses and never exceed the left input row count.
        return result.mapOutputRowCount(rowCount ->
                Math.min(leftStats.getOutputRowCount(), rowCount / Math.pow(UNKNOWN_FILTER_COEFFICIENT, numberOfRemainingClauses)));
    }

    /**
     * Adds the join-complement rows of an outer join to the inner-join estimate.
     *
     * <p>Symbols present in the complement stats get row-weighted nulls fractions. Their low value,
     * high value and NDV are taken from {@code sourceStats}, the outer side. Symbols known only to
     * the inner-join stats are treated as all-null in the complement rows.
     *
     * @param sourceStats stats of the outer input whose unmatched rows are added
     * @param innerJoinStats estimate of the matched (inner join) rows
     * @param joinComplementStats estimate of the unmatched outer rows
     */
    @VisibleForTesting
    PlanNodeStatsEstimate addJoinComplementStats(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate innerJoinStats, PlanNodeStatsEstimate joinComplementStats) {
        double innerJoinRowCount = innerJoinStats.getOutputRowCount();
        double joinComplementRowCount = joinComplementStats.getOutputRowCount();
        if (joinComplementRowCount == 0.0d) {
            return innerJoinStats;
        }

        double outputRowCount = innerJoinRowCount + joinComplementRowCount;
        PlanNodeStatsEstimate.Builder outputStats = PlanNodeStatsEstimate.buildFrom(innerJoinStats);
        outputStats.setOutputRowCount(outputRowCount);

        for (Symbol symbol : joinComplementStats.getSymbolsWithKnownStatistics()) {
            SymbolStatsEstimate outerSymbolStats = sourceStats.getSymbolStatistics(symbol);
            SymbolStatsEstimate innerJoinSymbolStats = innerJoinStats.getSymbolStatistics(symbol);
            SymbolStatsEstimate joinComplementSymbolStats = joinComplementStats.getSymbolStatistics(symbol);

            // Row-weighted average of the two nulls fractions.
            double newNullsFraction = (innerJoinSymbolStats.getNullsFraction() * innerJoinRowCount
                    + joinComplementSymbolStats.getNullsFraction() * joinComplementRowCount) / outputRowCount;
            outputStats.addSymbolStatistics(symbol, SymbolStatsEstimate.buildFrom(innerJoinSymbolStats)
                    .setLowValue(outerSymbolStats.getLowValue())
                    .setHighValue(outerSymbolStats.getHighValue())
                    .setDistinctValuesCount(outerSymbolStats.getDistinctValuesCount())
                    .setNullsFraction(newNullsFraction)
                    .build());
        }

        // Columns absent from the complement stats are null in every complement row.
        for (Symbol symbol : Sets.difference(innerJoinStats.getSymbolsWithKnownStatistics(), joinComplementStats.getSymbolsWithKnownStatistics())) {
            SymbolStatsEstimate innerJoinSymbolStats = innerJoinStats.getSymbolStatistics(symbol);
            double newNullsFraction = (innerJoinSymbolStats.getNullsFraction() * innerJoinRowCount + joinComplementRowCount) / outputRowCount;
            outputStats.addSymbolStatistics(symbol, innerJoinSymbolStats.mapNullsFraction(ignored -> newNullsFraction));
        }

        return outputStats.build();
    }

    /** Cartesian-product estimate: row counts multiply and each side's symbol stats are carried over. */
    private PlanNodeStatsEstimate crossJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, TypeProvider types) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount());
        node.getLeft().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, leftStats.getSymbolStatistics(symbol)));
        node.getRight().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, rightStats.getSymbolStatistics(symbol)));
        return normalizer.normalize(builder.build(), types);
    }

    /** Returns the join criteria with left and right swapped, for estimating the right side's complement. */
    private List<JoinNode.EquiJoinClause> flippedCriteria(JoinNode node) {
        return node.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::flip)
                .collect(ImmutableList.toImmutableList());
    }
}
