package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.spark-project.guava.collect.ImmutableList;
import org.spark-project.guava.collect.Lists;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.class */
public final class HiveExpandDistinctAggregatesRule extends RelOptRule {
    public static final HiveExpandDistinctAggregatesRule INSTANCE = new HiveExpandDistinctAggregatesRule(HiveAggregate.class, HiveProject.DEFAULT_PROJECT_FACTORY);
    private static RelFactories.ProjectFactory projFactory;

    public HiveExpandDistinctAggregatesRule(Class<? extends Aggregate> cls, RelFactories.ProjectFactory projectFactory) {
        super(operand(cls, any()));
        projFactory = projectFactory;
    }

    public void onMatch(RelOptRuleCall relOptRuleCall) {
        Aggregate aggregate = (Aggregate) relOptRuleCall.rel(0);
        if (aggregate.containsDistinctCall()) {
            int i = 0;
            LinkedHashSet linkedHashSet = new LinkedHashSet();
            for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
                if (aggregateCall.isDistinct()) {
                    ArrayList arrayList = new ArrayList();
                    Iterator it = aggregateCall.getArgList().iterator();
                    while (it.hasNext()) {
                        arrayList.add((Integer) it.next());
                    }
                    linkedHashSet.add(arrayList);
                } else {
                    i++;
                }
            }
            Util.permAssert(linkedHashSet.size() > 0, "containsDistinctCall lied");
            if (i == 0 && linkedHashSet.size() == 1) {
                Iterator it2 = ((List) linkedHashSet.iterator().next()).iterator();
                while (it2.hasNext()) {
                    Set<RelColumnOrigin> columnOrigins = RelMetadataQuery.getColumnOrigins(aggregate, ((Integer) it2.next()).intValue());
                    if (null != columnOrigins) {
                        for (RelColumnOrigin relColumnOrigin : columnOrigins) {
                            if (relColumnOrigin.getOriginTable().getPartColInfoMap().containsKey(Integer.valueOf(relColumnOrigin.getOriginColumnOrdinal()))) {
                                return;
                            }
                        }
                    }
                }
                relOptRuleCall.transformTo(convertMonopole(aggregate, (List) linkedHashSet.iterator().next()));
            }
        }
    }

    private RelNode convertMonopole(Aggregate aggregate, List<Integer> list) {
        HashMap hashMap = new HashMap();
        Aggregate createSelectDistinct = createSelectDistinct(aggregate, list, hashMap);
        ArrayList newArrayList = Lists.newArrayList(aggregate.getAggCallList());
        rewriteAggCalls(newArrayList, list, hashMap);
        return aggregate.copy(aggregate.getTraitSet(), createSelectDistinct, aggregate.indicator, ImmutableBitSet.range(aggregate.getGroupSet().cardinality()), (List) null, newArrayList);
    }

    private static void rewriteAggCalls(List<AggregateCall> list, List<Integer> list2, Map<Integer, Integer> map) {
        for (int i = 0; i < list.size(); i++) {
            AggregateCall aggregateCall = list.get(i);
            if (aggregateCall.isDistinct() && aggregateCall.getArgList().equals(list2)) {
                int size = aggregateCall.getArgList().size();
                ArrayList arrayList = new ArrayList(size);
                for (int i2 = 0; i2 < size; i2++) {
                    arrayList.add(map.get((Integer) aggregateCall.getArgList().get(i2)));
                }
                list.set(i, new AggregateCall(aggregateCall.getAggregation(), false, arrayList, aggregateCall.getType(), aggregateCall.getName()));
            }
        }
    }

    private static Aggregate createSelectDistinct(Aggregate aggregate, List<Integer> list, Map<Integer, Integer> map) {
        ArrayList arrayList = new ArrayList();
        RelNode input = aggregate.getInput();
        List fieldList = input.getRowType().getFieldList();
        Iterator it = aggregate.getGroupSet().iterator();
        while (it.hasNext()) {
            int intValue = ((Integer) it.next()).intValue();
            map.put(Integer.valueOf(intValue), Integer.valueOf(arrayList.size()));
            arrayList.add(RexInputRef.of2(intValue, fieldList));
        }
        for (Integer num : list) {
            if (map.get(num) == null) {
                map.put(num, Integer.valueOf(arrayList.size()));
                arrayList.add(RexInputRef.of2(num.intValue(), fieldList));
            }
        }
        return aggregate.copy(aggregate.getTraitSet(), projFactory.createProject(input, Pair.left(arrayList), Pair.right(arrayList)), false, ImmutableBitSet.range(arrayList.size()), (List) null, ImmutableList.of());
    }
}
