package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import io.prestosql.hive.$internal.com.google.common.collect.ImmutableList;
import io.prestosql.hive.$internal.com.google.common.collect.Lists;
import io.prestosql.hive.$internal.com.google.common.collect.Maps;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.CompositeList;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceFunctionsRule.class */
/**
 * Planner rule that rewrites {@code AVG}, {@code STDDEV_POP}, {@code STDDEV_SAMP},
 * {@code VAR_POP} and {@code VAR_SAMP} calls of a {@link HiveAggregate} into Hive
 * {@code SUM}/{@code COUNT} aggregates plus scalar arithmetic in a Project on top.
 *
 * <p>The reductions (x is the call's single argument) are:
 * <ul>
 *   <li>{@code AVG(x)} becomes {@code CAST(SUM(x) / COUNT(x))}.</li>
 *   <li>{@code VAR_POP(x)} becomes
 *       {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)}.</li>
 *   <li>{@code VAR_SAMP(x)} uses the same numerator, divided by
 *       {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END}.</li>
 *   <li>{@code STDDEV_POP(x)} and {@code STDDEV_SAMP(x)} are {@code POWER(variance, 0.5)}
 *       of the corresponding variance.</li>
 * </ul>
 *
 * <p>For the variance/stddev forms, x is first passed through {@code ensureType} to the
 * nullable result type. Every reduced expression is finally cast to the original call's type.
 * Aggregate calls that are not reducible are carried over to the new aggregate unchanged.
 */
public class HiveAggregateReduceFunctionsRule extends RelOptRule {

    /** Shared instance of this rule. */
    public static final HiveAggregateReduceFunctionsRule INSTANCE =
        new HiveAggregateReduceFunctionsRule();

    /**
     * Creates the rule.
     *
     * <p>The rule matches any {@link HiveAggregate} and builds its replacement with
     * {@link HiveRelFactories#HIVE_BUILDER}.
     */
    public HiveAggregateReduceFunctionsRule() {
        super(operand(HiveAggregate.class, any()), HiveRelFactories.HIVE_BUILDER, (String) null);
    }

    /**
     * Fires only if the matched aggregate contains at least one reducible call.
     *
     * @param relOptRuleCall rule call whose first rel is the matched aggregate
     * @return {@code true} if the base check passes and a reducible call is present
     */
    @Override
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        if (!super.matches(relOptRuleCall)) {
            return false;
        }
        // rels is typed RelNode[]; the HiveAggregate operand guarantees this cast succeeds.
        final Aggregate aggregate = (Aggregate) relOptRuleCall.rels[0];
        return containsAvgStddevVarCall(aggregate.getAggCallList());
    }

    /**
     * Replaces the matched aggregate with its reduced equivalent.
     *
     * @param relOptRuleCall rule call whose first rel is the matched aggregate
     */
    @Override
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        reduceAggs(relOptRuleCall, (Aggregate) relOptRuleCall.rels[0]);
    }

    /** Returns whether any call in {@code aggCalls} has a kind this rule can reduce. */
    private boolean containsAvgStddevVarCall(List<AggregateCall> aggCalls) {
        for (AggregateCall aggCall : aggCalls) {
            if (isReducible(aggCall.getAggregation().getKind())) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether {@code kind} belongs to {@link SqlKind#AVG_AGG_FUNCTIONS}. */
    private boolean isReducible(SqlKind kind) {
        return SqlKind.AVG_AGG_FUNCTIONS.contains(kind);
    }

    /**
     * Builds the replacement for {@code oldAggRel} and registers it with {@code ruleCall}.
     *
     * <p>The replacement has the shape Project(Aggregate([Project](input))). The inner Project
     * is only added when a reduction needed derived input expressions.
     *
     * @param ruleCall rule call used for the RelBuilder and for {@code transformTo}
     * @param oldAggRel aggregate being rewritten
     */
    private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
        final int groupCount = oldAggRel.getGroupCount();
        final int indicatorCount = oldAggRel.getIndicatorCount();

        // Calls of the new aggregate, plus the call -> reference mapping handed to
        // RexBuilder.addAggCall. Calcite uses the mapping to reuse a call already added.
        final List<AggregateCall> newCalls = new ArrayList<>();
        final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();

        // Output expressions of the top Project. Group keys and indicator columns pass through.
        final List<RexNode> projList = new ArrayList<>();
        for (int i = 0; i < groupCount + indicatorCount; i++) {
            projList.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        // Expressions over the aggregate's input. A reduction may append derived expressions
        // (e.g. a CAST of x, or x * x); if so, an extra Project is inserted below.
        final RelBuilder relBuilder = ruleCall.builder();
        relBuilder.push(oldAggRel.getInput());
        final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());

        // One output expression per original aggregate call, in the original order.
        for (AggregateCall oldCall : oldCalls) {
            projList.add(reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        final int extraArgCount =
            inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
        if (extraArgCount > 0) {
            // Existing fields keep their names; the appended expressions are left unnamed.
            relBuilder.project(inputExprs,
                CompositeList.of(relBuilder.peek().getRowType().getFieldNames(),
                    Collections.<String>nCopies(extraArgCount, null)));
        }
        newAggregateRel(relBuilder, oldAggRel, newCalls);
        // Restore the original field names and row type on top of the new aggregate.
        relBuilder.project(projList, oldAggRel.getRowType().getFieldNames())
            .convert(oldAggRel.getRowType(), false);
        ruleCall.transformTo(relBuilder.build());
    }

    /**
     * Rewrites a single aggregate call.
     *
     * <p>Non-reducible calls are added to {@code newCalls} unchanged. Reducible calls are
     * dispatched to the matching reduction.
     *
     * @return expression over the new aggregate's output that replaces {@code oldCall}
     * @throws AssertionError via {@code Util.unexpected} for a reducible kind with no reduction
     */
    private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
            List<RexNode> inputExprs) {
        final SqlKind kind = oldCall.getAggregation().getKind();
        if (!isReducible(kind)) {
            return oldAggRel.getCluster().getRexBuilder().addAggCall(oldCall,
                oldAggRel.getGroupCount(), oldAggRel.indicator, newCalls, aggCallMapping,
                SqlTypeUtil.projectTypes(oldAggRel.getInput().getRowType(), oldCall.getArgList()));
        }
        // reduceStddev flags: (biased = divide by COUNT rather than COUNT - 1, sqrt).
        switch (kind) {
            case AVG:
                return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
            case STDDEV_POP:
                return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping,
                    inputExprs);
            case STDDEV_SAMP:
                return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping,
                    inputExprs);
            case VAR_POP:
                return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping,
                    inputExprs);
            case VAR_SAMP:
                return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping,
                    inputExprs);
            default:
                throw Util.unexpected(kind);
        }
    }

    /**
     * Creates an aggregate call over input ordinal {@code argOrdinal}, with an explicitly
     * supplied operand type.
     *
     * <p>The return type is inferred from a binding built with {@code operandType}. This is
     * used when the operand is a derived expression that is not a field of the aggregate's
     * input row type.
     */
    private AggregateCall createAggregateCallWithBinding(RelDataTypeFactory typeFactory,
            SqlAggFunction aggFunction, RelDataType operandType, Aggregate oldAggRel,
            AggregateCall oldCall, int argOrdinal) {
        final Aggregate.AggCallBinding binding = new Aggregate.AggCallBinding(typeFactory,
            aggFunction, ImmutableList.of(operandType), oldAggRel.getGroupCount(),
            oldCall.filterArg >= 0);
        return AggregateCall.create(aggFunction, oldCall.isDistinct(), oldCall.isApproximate(),
            ImmutableIntList.of(argOrdinal), oldCall.filterArg,
            aggFunction.inferReturnType(binding), (String) null);
    }

    /**
     * Creates a Hive SUM function that reuses the return-type inference, operand-type
     * inference and operand-type checker of {@code oldCall}'s function.
     */
    private static HiveSqlSumAggFunction createSumFunction(AggregateCall oldCall) {
        final SqlAggFunction oldFunction = oldCall.getAggregation();
        return new HiveSqlSumAggFunction(oldCall.isDistinct(),
            oldFunction.getReturnTypeInference(), oldFunction.getOperandTypeInference(),
            oldFunction.getOperandTypeChecker());
    }

    /**
     * Creates {@code SUM(argList)} with {@code oldCall}'s flags.
     *
     * <p>The result type is left for {@code AggregateCall.create} to infer.
     */
    private static AggregateCall createSumCall(Aggregate oldAggRel, AggregateCall oldCall,
            List<Integer> argList) {
        return AggregateCall.create(createSumFunction(oldCall), oldCall.isDistinct(),
            oldCall.isApproximate(), argList, oldCall.filterArg, oldAggRel.getGroupCount(),
            oldAggRel.getInput(), (RelDataType) null, (String) null);
    }

    /**
     * Creates {@code COUNT} over {@code oldCall}'s arguments, with a nullable BIGINT result.
     *
     * <p>The function reuses {@code oldCall}'s operand-type inference and checker.
     */
    private static AggregateCall createCountCall(RelDataTypeFactory typeFactory,
            Aggregate oldAggRel, AggregateCall oldCall) {
        final RelDataType countType = typeFactory.createTypeWithNullability(
            typeFactory.createSqlType(SqlTypeName.BIGINT), true);
        final SqlAggFunction oldFunction = oldCall.getAggregation();
        final HiveSqlCountAggFunction countFunction = new HiveSqlCountAggFunction(
            oldCall.isDistinct(), ReturnTypes.explicit(countType),
            oldFunction.getOperandTypeInference(), oldFunction.getOperandTypeChecker());
        return AggregateCall.create(countFunction, oldCall.isDistinct(), oldCall.isApproximate(),
            oldCall.getArgList(), oldCall.filterArg, oldAggRel.getGroupCount(),
            oldAggRel.getInput(), countType, (String) null);
    }

    /**
     * Reduces {@code AVG(x)} to {@code CAST(SUM(x) / COUNT(x))}.
     *
     * <p>SUM is added to {@code newCalls} before COUNT. The SUM reference is passed through
     * {@code ensureType} to the AVG result type before the division.
     */
    private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
            List<RexNode> inputExprs) {
        final int nGroups = oldAggRel.getGroupCount();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        final int avgInputOrdinal = oldCall.getArgList().get(0);
        final RelDataType avgInputType = typeFactory.createTypeWithNullability(
            getFieldType(oldAggRel.getInput(), avgInputOrdinal), true);

        final AggregateCall sumCall = createSumCall(oldAggRel, oldCall, oldCall.getArgList());
        final AggregateCall countCall = createCountCall(typeFactory, oldAggRel, oldCall);

        // Both references point into the output of the new aggregate.
        final RexNode sumRef = rexBuilder.addAggCall(sumCall, nGroups, oldAggRel.indicator,
            newCalls, aggCallMapping, ImmutableList.of(avgInputType));
        final RexNode countRef = rexBuilder.addAggCall(countCall, nGroups, oldAggRel.indicator,
            newCalls, aggCallMapping, ImmutableList.of(avgInputType));

        final RexNode numerator = rexBuilder.ensureType(oldCall.getType(), sumRef, true);
        final RexNode quotient =
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, countRef);
        return rexBuilder.makeCast(oldCall.getType(), quotient);
    }

    /**
     * Reduces a variance or standard-deviation call.
     *
     * <p>The numerator is {@code SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)}. It is divided by
     * {@code COUNT(x)} when {@code biased}, otherwise by
     * {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END}. The quotient is raised
     * to 0.5 when {@code sqrt} is set.
     *
     * @param biased {@code true} for the population forms ({@code *_POP})
     * @param sqrt {@code true} for the standard-deviation forms ({@code STDDEV_*})
     */
    private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased,
            boolean sqrt, List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        final int nGroups = oldAggRel.getGroupCount();
        final RelOptCluster cluster = oldAggRel.getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();
        final RelDataTypeFactory typeFactory = cluster.getTypeFactory();

        assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
        final RelDataType oldCallType =
            typeFactory.createTypeWithNullability(oldCall.getType(), true);

        // x brought to the nullable result type, then x * x. Both are registered as input
        // expressions, in this order.
        final RexNode argRef =
            rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
        final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
        final RexNode argSquared =
            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
        final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);

        // SUM(x * x). The operand is a derived expression, so its type is supplied explicitly.
        final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory,
            createSumFunction(oldCall), argSquared.getType(), oldAggRel, oldCall,
            argSquaredOrdinal);
        final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups,
            oldAggRel.indicator, newCalls, aggCallMapping,
            ImmutableList.of(sumArgSquaredAggCall.getType()));

        // SUM(x), then squared.
        // NOTE(review): argRefOrdinal can lie past the input's field count when a CAST was
        // appended. The null result type is then inferred from oldCall's return-type inference,
        // which presumably does not read operand types. TODO confirm.
        final AggregateCall sumArgAggCall =
            createSumCall(oldAggRel, oldCall, ImmutableIntList.of(argRefOrdinal));
        final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups,
            oldAggRel.indicator, newCalls, aggCallMapping,
            ImmutableList.of(sumArgAggCall.getType()));
        final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
        final RexNode sumSquaredArg =
            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);

        // COUNT(x)
        final RexNode countArg = rexBuilder.addAggCall(
            createCountCall(typeFactory, oldAggRel, oldCall), nGroups, oldAggRel.indicator,
            newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));

        // SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
        final RexNode avgSumSquaredArg =
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
        final RexNode diff =
            rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);

        final RexNode denominator;
        if (biased) {
            denominator = countArg;
        } else {
            // COUNT(x) = 1 would give COUNT(x) - 1 = 0; yield NULL instead.
            final RexNode one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
            final RexNode typedNull =
                rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
            final RexNode countMinusOne =
                rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
            final RexNode countEqualsOne =
                rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
            denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqualsOne,
                typedNull, countMinusOne);
        }

        final RexNode variance = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
        RexNode result = variance;
        if (sqrt) {
            result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, variance,
                rexBuilder.makeExactLiteral(new BigDecimal("0.5")));
        }
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Returns the ordinal of {@code rexNode} in {@code list}, appending it if absent.
     *
     * <p>Matching compares {@code toString()} output rather than {@code equals}.
     * NOTE(review): presumably because RexNode equality is identity-based in this Calcite
     * version; confirm before switching to {@code indexOf}.
     */
    private static int lookupOrAdd(List<RexNode> list, RexNode rexNode) {
        final String digest = rexNode.toString();
        for (int i = 0; i < list.size(); i++) {
            if (list.get(i).toString().equals(digest)) {
                return i;
            }
        }
        list.add(rexNode);
        return list.size() - 1;
    }

    /**
     * Pushes the new aggregate onto {@code relBuilder}.
     *
     * <p>The new aggregate keeps {@code aggregate}'s group set and grouping sets, and uses
     * {@code list} as its calls. Subclasses may override this to build a different aggregate.
     */
    protected void newAggregateRel(RelBuilder relBuilder, Aggregate aggregate,
            List<AggregateCall> list) {
        relBuilder.aggregate(
            relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()), list);
    }

    /** Returns the type of field {@code i} of {@code relNode}'s row type. */
    private RelDataType getFieldType(RelNode relNode, int i) {
        return relNode.getRowType().getFieldList().get(i).getType();
    }
}
