package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.log.Logger;
import io.hetu.core.spi.cube.CubeAggregateFunction;
import io.hetu.core.spi.cube.CubeMetadata;
import io.hetu.core.spi.cube.aggregator.AggregationSignature;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.SymbolAllocator;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnNotFoundException;
import io.prestosql.spi.connector.CubeNotFoundException;
import io.prestosql.spi.connector.QualifiedObjectName;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.metadata.TableHandle;
import io.prestosql.spi.plan.AggregationNode;
import io.prestosql.spi.plan.Assignments;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.PlanNodeIdAllocator;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.relation.CallExpression;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.planner.PlanSymbolAllocator;
import io.prestosql.sql.planner.SymbolUtils;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.CubeRewriteResult;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.ArithmeticBinaryExpression;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.SymbolReference;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/* Decompiled (JADX) from io/prestosql/sql/planner/optimizations/AggregationRewriteWithCube.class; not original source. */
public class AggregationRewriteWithCube {
    private static final Logger log = Logger.get(AggregationRewriteWithCube.class);
    private final Session session;
    private final SymbolAllocator symbolAllocator;
    private final Metadata metadata;
    private final PlanNodeIdAllocator idAllocator;
    private final Map<String, Object> symbolMappings;
    private final CubeMetadata cubeMetadata;
    private final TypeProvider typeProvider;

    public AggregationRewriteWithCube(Metadata metadata, Session session, PlanSymbolAllocator planSymbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, Map<String, Object> map, CubeMetadata cubeMetadata) {
        this.session = session;
        this.symbolAllocator = planSymbolAllocator;
        this.metadata = metadata;
        this.idAllocator = planNodeIdAllocator;
        this.symbolMappings = map;
        this.cubeMetadata = cubeMetadata;
        this.typeProvider = planSymbolAllocator.getTypes();
    }

    public PlanNode rewrite(AggregationNode aggregationNode, PlanNode planNode) {
        QualifiedObjectName valueOf = QualifiedObjectName.valueOf(this.cubeMetadata.getCubeName());
        TableHandle orElseThrow = this.metadata.getTableHandle(this.session, valueOf).orElseThrow(() -> {
            return new CubeNotFoundException(valueOf.toString());
        });
        Map<String, ColumnHandle> columnHandles = this.metadata.getColumnHandles(this.session, orElseThrow);
        CubeRewriteResult createScanNode = createScanNode(aggregationNode, planNode, orElseThrow, columnHandles);
        PlanNode tableScanNode = createScanNode.getTableScanNode();
        if (planNode != null) {
            tableScanNode = new FilterNode(this.idAllocator.getNextId(), tableScanNode, OriginalExpressionUtils.castToRowExpression(rewriteSymbolReferenceUsingColumnName(OriginalExpressionUtils.castToExpression(((FilterNode) planNode).getPredicate()), this.symbolMappings)));
        }
        ArrayList arrayList = new ArrayList(aggregationNode.getGroupingKeys().size());
        Iterator it = aggregationNode.getGroupingKeys().iterator();
        while (it.hasNext()) {
            Object obj = this.symbolMappings.get(((Symbol) it.next()).getName());
            if (obj instanceof ColumnHandle) {
                arrayList.add(new Symbol(((ColumnHandle) obj).getColumnName()));
            }
        }
        ImmutableMap.Builder builder = ImmutableMap.builder();
        for (CubeRewriteResult.AggregatorSource aggregatorSource : createScanNode.getAggregationColumns()) {
            Type type = createScanNode.getSymbolMetadataMap().get(aggregatorSource.getOriginalAggSymbol()).getType();
            TypeSignature typeSignature = type.getTypeSignature();
            ColumnHandle columnHandle = (ColumnHandle) createScanNode.getTableScanNode().getAssignments().get(aggregatorSource.getScanSymbol());
            AggregationSignature aggregationSignature = (AggregationSignature) this.cubeMetadata.getAggregationSignature(this.metadata.getColumnMetadata(this.session, orElseThrow, columnHandle).getName()).orElseThrow(() -> {
                return new ColumnNotFoundException(new SchemaTableName(valueOf.getSchemaName(), valueOf.getObjectName()), columnHandle.getColumnName());
            });
            String function = CubeAggregateFunction.COUNT.getName().equals(aggregationSignature.getFunction()) ? StarTreeAggregationRule.SUM : aggregationSignature.getFunction();
            SymbolReference symbolReference = SymbolUtils.toSymbolReference(aggregatorSource.getScanSymbol());
            builder.put(aggregatorSource.getOriginalAggSymbol(), new AggregationNode.Aggregation(new CallExpression(function, this.metadata.getFunctionAndTypeManager().lookupFunction(function, TypeSignatureProvider.fromTypeSignatures(typeSignature)), type, ImmutableList.of(OriginalExpressionUtils.castToRowExpression(symbolReference))), ImmutableList.of(OriginalExpressionUtils.castToRowExpression(symbolReference)), false, Optional.empty(), Optional.empty(), Optional.empty()));
        }
        PlanNode aggregationNode2 = new AggregationNode(this.idAllocator.getNextId(), tableScanNode, builder.build(), AggregationNode.singleGroupingSet(arrayList), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty());
        if (!createScanNode.getAvgAggregationColumns().isEmpty()) {
            HashSet hashSet = new HashSet();
            createScanNode.getAvgAggregationColumns().forEach(averageAggregatorSource -> {
                hashSet.add(averageAggregatorSource.getCount());
                hashSet.add(averageAggregatorSource.getSum());
            });
            HashMap hashMap = new HashMap();
            Iterator it2 = aggregationNode.getOutputSymbols().iterator();
            while (it2.hasNext()) {
                Object obj2 = this.symbolMappings.get(((Symbol) it2.next()).getName());
                if (obj2 instanceof ColumnHandle) {
                    Symbol symbol = new Symbol(columnHandles.get(((ColumnHandle) obj2).getColumnName()).getColumnName());
                    hashSet.add(symbol);
                    hashMap.put(symbol, SymbolUtils.toSymbolReference(symbol));
                }
            }
            for (Symbol symbol2 : createScanNode.getTableScanNode().getOutputSymbols()) {
                if (!hashSet.contains(symbol2)) {
                    hashMap.put(symbol2, SymbolUtils.toSymbolReference(symbol2));
                }
            }
            for (CubeRewriteResult.AverageAggregatorSource averageAggregatorSource2 : createScanNode.getAvgAggregationColumns()) {
                hashMap.put(averageAggregatorSource2.getOriginalAggSymbol(), new Cast(new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, SymbolUtils.toSymbolReference(averageAggregatorSource2.getSum()), new Cast(SymbolUtils.toSymbolReference(averageAggregatorSource2.getCount()), createScanNode.getSymbolMetadataMap().get(averageAggregatorSource2.getSum()).getType().getTypeSignature().toString())), this.typeProvider.get(averageAggregatorSource2.getOriginalAggSymbol()).getTypeSignature().toString()));
            }
            aggregationNode2 = new ProjectNode(this.idAllocator.getNextId(), aggregationNode2, new Assignments((Map) hashMap.entrySet().stream().collect(Collectors.toMap((v0) -> {
                return v0.getKey();
            }, entry -> {
                return OriginalExpressionUtils.castToRowExpression((Expression) entry.getValue());
            }))));
        }
        if (!aggregationNode2.getOutputSymbols().equals(aggregationNode.getOutputSymbols())) {
            HashMap hashMap2 = new HashMap();
            for (Symbol symbol3 : aggregationNode.getOutputSymbols()) {
                Object obj3 = this.symbolMappings.get(symbol3.getName());
                if (obj3 instanceof ColumnHandle) {
                    hashMap2.put(symbol3, new SymbolReference(columnHandles.get(((ColumnHandle) obj3).getColumnName()).getColumnName()));
                } else {
                    hashMap2.put(symbol3, SymbolUtils.toSymbolReference(symbol3));
                }
            }
            aggregationNode2 = new ProjectNode(this.idAllocator.getNextId(), aggregationNode2, new Assignments((Map) hashMap2.entrySet().stream().collect(Collectors.toMap((v0) -> {
                return v0.getKey();
            }, entry2 -> {
                return OriginalExpressionUtils.castToRowExpression((Expression) entry2.getValue());
            }))));
        }
        return aggregationNode2;
    }

    private static Expression rewriteSymbolReferenceUsingColumnName(Expression expression, Map<String, Object> map) {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Map<String, Object>>() { // from class: io.prestosql.sql.planner.optimizations.AggregationRewriteWithCube.1
            public Expression rewriteSymbolReference(SymbolReference symbolReference, Map<String, Object> map2, ExpressionTreeRewriter<Map<String, Object>> expressionTreeRewriter) {
                return new SymbolReference(((ColumnHandle) map2.get(symbolReference.getName())).getColumnName());
            }

            public /* bridge */ /* synthetic */ Expression rewriteSymbolReference(SymbolReference symbolReference, Object obj, ExpressionTreeRewriter expressionTreeRewriter) {
                return rewriteSymbolReference(symbolReference, (Map<String, Object>) obj, (ExpressionTreeRewriter<Map<String, Object>>) expressionTreeRewriter);
            }
        }, expression, map);
    }

    /* JADX WARN: Can't fix incorrect switch cases order, some code will duplicate */
    /* JADX WARN: Code restructure failed: missing block: B:100:0x0687, code lost:
    
        r0.add(new io.prestosql.sql.planner.optimizations.CubeRewriteResult.AverageAggregatorSource(r0, r40, r44));
     */
    /* JADX WARN: Code restructure failed: missing block: B:102:0x063f, code lost:
    
        r0 = r0.entrySet().iterator();
     */
    /* JADX WARN: Code restructure failed: missing block: B:104:0x0654, code lost:
    
        if (r0.hasNext() == false) goto L149;
     */
    /* JADX WARN: Code restructure failed: missing block: B:105:0x0657, code lost:
    
        r0 = (java.util.Map.Entry) r0.next();
     */
    /* JADX WARN: Code restructure failed: missing block: B:106:0x0672, code lost:
    
        if (((io.prestosql.spi.connector.ColumnHandle) r0.getValue()).equals(r0) == false) goto L150;
     */
    /* JADX WARN: Code restructure failed: missing block: B:108:0x0675, code lost:
    
        r44 = (io.prestosql.spi.plan.Symbol) r0.getKey();
     */
    /* JADX WARN: Code restructure failed: missing block: B:111:0x0532, code lost:
    
        r0 = r0.entrySet().iterator();
     */
    /* JADX WARN: Code restructure failed: missing block: B:113:0x0547, code lost:
    
        if (r0.hasNext() == false) goto L152;
     */
    /* JADX WARN: Code restructure failed: missing block: B:114:0x054a, code lost:
    
        r0 = (java.util.Map.Entry) r0.next();
     */
    /* JADX WARN: Code restructure failed: missing block: B:115:0x0565, code lost:
    
        if (((io.prestosql.spi.connector.ColumnHandle) r0.getValue()).equals(r0) == false) goto L153;
     */
    /* JADX WARN: Code restructure failed: missing block: B:117:0x0568, code lost:
    
        r40 = (io.prestosql.spi.plan.Symbol) r0.getKey();
     */
    /* JADX WARN: Code restructure failed: missing block: B:122:0x06bd, code lost:
    
        throw new io.prestosql.spi.PrestoException(io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unsupported aggregation function " + r0);
     */
    /* JADX WARN: Code restructure failed: missing block: B:125:0x03e0, code lost:
    
        r0 = new io.hetu.core.spi.cube.aggregator.AggregationSignature(r0, r0, r0);
        r0 = r16.get((java.lang.String) r12.cubeMetadata.getColumn(r0).orElseThrow(() -> { // java.util.function.Supplier.get():java.lang.Object
            return lambda$createScanNode$6(r1);
        }));
     */
    /* JADX WARN: Code restructure failed: missing block: B:126:0x0420, code lost:
    
        if (r0.containsValue(r0) != false) goto L144;
     */
    /* JADX WARN: Code restructure failed: missing block: B:128:0x0423, code lost:
    
        r0 = r12.metadata.getColumnMetadata(r12.session, r15, r0);
        r0.put(r0, r0);
        r0.put(r0, r0);
        r0.add(r0);
        r0.add(new io.prestosql.sql.planner.optimizations.CubeRewriteResult.AggregatorSource(r0, r0));
     */
    /* JADX WARN: Code restructure failed: missing block: B:92:0x03bf, code lost:
    
        switch(r33) {
            case 0: goto L134;
            case 1: goto L134;
            case 2: goto L134;
            case 3: goto L134;
            case 4: goto L125;
            default: goto L126;
        };
     */
    /* JADX WARN: Code restructure failed: missing block: B:94:0x046d, code lost:
    
        r0 = new io.hetu.core.spi.cube.aggregator.AggregationSignature(io.hetu.core.spi.cube.CubeAggregateFunction.SUM.getName(), r0, r0);
        r0 = r16.get((java.lang.String) r12.cubeMetadata.getColumn(r0).orElseThrow(() -> { // java.util.function.Supplier.get():java.lang.Object
            return lambda$createScanNode$7(r1);
        }));
        r40 = null;
     */
    /* JADX WARN: Code restructure failed: missing block: B:95:0x04b4, code lost:
    
        if (r0.containsValue(r0) != false) goto L82;
     */
    /* JADX WARN: Code restructure failed: missing block: B:96:0x04b7, code lost:
    
        r0 = r12.metadata.getColumnMetadata(r12.session, r15, r0);
        r40 = r12.symbolAllocator.newSymbol("sum_" + r0 + "_" + r0.getName(), r0.getType());
        r0.add(r40);
        r0.put(r40, r0);
        r0.put(r40, r0);
        r0.add(new io.prestosql.sql.planner.optimizations.CubeRewriteResult.AggregatorSource(r40, r40));
     */
    /* JADX WARN: Code restructure failed: missing block: B:97:0x057a, code lost:
    
        r0 = new io.hetu.core.spi.cube.aggregator.AggregationSignature(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT.getName(), r0, r0);
        r0 = r16.get((java.lang.String) r12.cubeMetadata.getColumn(r0).orElseThrow(() -> { // java.util.function.Supplier.get():java.lang.Object
            return lambda$createScanNode$8(r1);
        }));
        r44 = null;
     */
    /* JADX WARN: Code restructure failed: missing block: B:98:0x05c1, code lost:
    
        if (r0.containsValue(r0) != false) goto L92;
     */
    /* JADX WARN: Code restructure failed: missing block: B:99:0x05c4, code lost:
    
        r0 = r12.metadata.getColumnMetadata(r12.session, r15, r0);
        r44 = r12.symbolAllocator.newSymbol("count_" + r0 + "_" + r0.getName(), r0.getType());
        r0.add(r44);
        r0.put(r44, r0);
        r0.put(r44, r0);
        r0.add(new io.prestosql.sql.planner.optimizations.CubeRewriteResult.AggregatorSource(r44, r44));
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public io.prestosql.sql.planner.optimizations.CubeRewriteResult createScanNode(io.prestosql.spi.plan.AggregationNode r13, io.prestosql.spi.plan.PlanNode r14, io.prestosql.spi.metadata.TableHandle r15, java.util.Map<java.lang.String, io.prestosql.spi.connector.ColumnHandle> r16) {
        /*
            Method dump skipped, instructions count: 1813
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: io.prestosql.sql.planner.optimizations.AggregationRewriteWithCube.createScanNode(io.prestosql.spi.plan.AggregationNode, io.prestosql.spi.plan.PlanNode, io.prestosql.spi.metadata.TableHandle, java.util.Map):io.prestosql.sql.planner.optimizations.CubeRewriteResult");
    }
}
