package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import io.hetu.core.spi.cube.CubeMetadata;
import io.hetu.core.spi.cube.CubeStatement;
import io.hetu.core.spi.cube.io.CubeMetaStore;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.cube.CubeManager;
import io.prestosql.cube.CubeStatementGenerator;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.PrestoWarning;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.StandardWarningCode;
import io.prestosql.spi.metadata.TableHandle;
import io.prestosql.spi.plan.AggregationNode;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.PlanNodeIdAllocator;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.TableScanNode;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.relation.CallExpression;
import io.prestosql.spi.relation.RowExpression;
import io.prestosql.spi.relation.VariableReferenceExpression;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.parser.ParsingOptions;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.planner.ExpressionDomainTranslator;
import io.prestosql.sql.planner.PlanSymbolAllocator;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.relational.OriginalExpressionUtils;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.Literal;
import io.prestosql.sql.tree.SymbolReference;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.LongSupplier;
import java.util.stream.Collectors;

/**
 * Iterative optimizer rule that tries to answer a supported aggregation from a star-tree cube instead of
 * scanning the source table.
 *
 * <p>The rule matches an {@link AggregationNode} above a {@link TableScanNode}, with up to three optional
 * {@link ProjectNode}s and an optional {@link FilterNode} in between. When every captured projection is a
 * plain symbol/literal reference (optionally wrapped in a cast or single-argument call), the rule:
 * <ol>
 *   <li>maps the aggregation's symbols back to table columns or literals,</li>
 *   <li>asks the cube metastore for cubes over the table,</li>
 *   <li>keeps the cubes whose predicate covers the query filter and whose recorded source-table update time
 *       is not older than the table's last modification,</li>
 *   <li>rewrites the plan against the most recently updated remaining cube via
 *       {@code AggregationRewriteWithCube}.</li>
 * </ol>
 */
public class StarTreeAggregationRule implements Rule<AggregationNode> {
    private static final Logger LOGGER = Logger.get(StarTreeAggregationRule.class);
    public static final String AVG = "avg";
    public static final String COUNT = "count";
    public static final String SUM = "sum";
    public static final String MIN = "min";
    public static final String MAX = "max";

    /** Aggregation function names that may be answered from a star-tree cube. */
    private static final Set<String> SUPPORTED_FUNCTIONS = ImmutableSet.of(AVG, COUNT, SUM, MIN, MAX);

    private static final Capture<Optional<PlanNode>> OPTIONAL_PRE_PROJECT_ONE = Capture.newCapture();
    private static final Capture<Optional<PlanNode>> OPTIONAL_PRE_PROJECT_TWO = Capture.newCapture();
    private static final Capture<Optional<PlanNode>> OPTIONAL_FILTER = Capture.newCapture();
    private static final Capture<Optional<PlanNode>> OPTIONAL_POST_PROJECT = Capture.newCapture();
    private static final Capture<TableScanNode> TABLE_SCAN = Capture.newCapture();

    /**
     * Aggregation -> [Project] -> [Project] -> [Filter] -> [Project] -> TableScan.
     *
     * <p>The bracketed levels are matched through {@code Patterns.optionalSource}. Each one is recorded in its
     * Optional capture only when the node at that level has the expected type.
     */
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(StarTreeAggregationRule::isSupportedAggregation)
            .with(Patterns.optionalSource(ProjectNode.class).matching(Patterns.anyPlan()
                    .capturedAsIf(node -> node instanceof ProjectNode, OPTIONAL_PRE_PROJECT_ONE)
                    .with(Patterns.optionalSource(ProjectNode.class).matching(Patterns.anyPlan()
                            .capturedAsIf(node -> node instanceof ProjectNode, OPTIONAL_PRE_PROJECT_TWO)
                            .with(Patterns.optionalSource(FilterNode.class).matching(Patterns.anyPlan()
                                    .capturedAsIf(node -> node instanceof FilterNode, OPTIONAL_FILTER)
                                    .with(Patterns.optionalSource(ProjectNode.class).matching(Patterns.anyPlan()
                                            .capturedAsIf(node -> node instanceof ProjectNode, OPTIONAL_POST_PROJECT)
                                            .with(Patterns.source().matching(Patterns.tableScan().capturedAs(TABLE_SCAN)))))))))));

    private final CubeManager cubeManager;
    private final Metadata metadata;

    /**
     * Lazily resolved by {@link #isEnabled(Session)}.
     *
     * <p>Declared {@code volatile} because it is read outside the lock on the fast path of a double-checked
     * initialization. Without it, other threads are not guaranteed to see a safely published reference.
     */
    private volatile CubeMetaStore cubeMetaStore;

    /**
     * @param cubeManager provides the star-tree cube provider and its metastore; must not be null
     * @param metadata    used to read table metadata, modification times and to translate predicates; must not be null
     */
    public StarTreeAggregationRule(CubeManager cubeManager, Metadata metadata) {
        this.cubeManager = Objects.requireNonNull(cubeManager, "cubeManager is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Decides whether the rule may run for this session.
     *
     * <p>The rule is enabled only when the star-tree session property is on and a star-tree cube provider is
     * registered. On first success this method also caches the star-tree metastore, which {@link #optimize} reads.
     *
     * @return true when the rule can run; false when the feature is off or no provider/metastore is available
     */
    @Override
    public boolean isEnabled(Session session) {
        if (!SystemSessionProperties.isEnableStarTreeIndex(session) || !cubeManager.getCubeProvider(CubeManager.STAR_TREE).isPresent()) {
            return false;
        }
        if (cubeMetaStore != null) {
            return true;
        }
        synchronized (this) {
            if (cubeMetaStore == null) {
                Optional<CubeMetaStore> metaStore = cubeManager.getMetaStore(CubeManager.STAR_TREE);
                if (!metaStore.isPresent()) {
                    return false;
                }
                cubeMetaStore = metaStore.get();
            }
            return true;
        }
    }

    /**
     * Validates the captured projections, builds the symbol mappings and delegates to {@link #optimize}.
     *
     * <p>Planning errors of the caught types are logged and turn into "no rewrite" rather than failing the query.
     * The total time spent is always logged at debug level.
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        long startTimeMillis = System.currentTimeMillis();
        Optional<PlanNode> preProjectOne = captures.get(OPTIONAL_PRE_PROJECT_ONE);
        Optional<PlanNode> preProjectTwo = captures.get(OPTIONAL_PRE_PROJECT_TWO);
        Optional<PlanNode> postProject = captures.get(OPTIONAL_POST_PROJECT);
        Optional<PlanNode> filterNode = captures.get(OPTIONAL_FILTER);
        TableScanNode tableScanNode = captures.get(TABLE_SCAN);
        if (!supportedProjectNode(preProjectOne) || !supportedProjectNode(preProjectTwo) || !supportedProjectNode(postProject)) {
            return Rule.Result.empty();
        }

        // Projections in pattern order, from the one nearest the aggregation down to the one nearest the scan.
        List<ProjectNode> projectNodes = new ArrayList<>(3);
        preProjectOne.map(ProjectNode.class::cast).ifPresent(projectNodes::add);
        preProjectTwo.map(ProjectNode.class::cast).ifPresent(projectNodes::add);
        postProject.map(ProjectNode.class::cast).ifPresent(projectNodes::add);

        try {
            Map<String, Object> symbolMappings = buildSymbolMappings(aggregationNode, projectNodes, filterNode, tableScanNode);
            return optimize(aggregationNode, filterNode, tableScanNode, symbolMappings, context.getSession(),
                    context.getSymbolAllocator(), context.getIdAllocator(), context.getWarningCollector());
        } catch (IllegalArgumentException | IllegalStateException | UnsupportedOperationException | PrestoException e) {
            // Pass the exception itself so the stack trace is logged, and keep the message out of the format string.
            LOGGER.warn(e, "Encountered exception '%s' while applying the StarTreeAggregationRule", e.getMessage());
            return Rule.Result.empty();
        } finally {
            LOGGER.debug("Star-tree total optimization time: %d millis", System.currentTimeMillis() - startTimeMillis);
        }
    }

    /**
     * Rewrites {@code aggregationNode} to read from the best matching, non-expired star-tree cube.
     *
     * @param aggregationNode  the aggregation to rewrite
     * @param filterNode       the optional filter between the aggregation and the scan (a {@link FilterNode} when present)
     * @param tableScanNode    the scan of the source table
     * @param symbolMappings   mappings produced by {@link #buildSymbolMappings}
     * @param session          current session
     * @param symbolAllocator  symbol allocator of the current optimization context
     * @param idAllocator      plan node id allocator of the current optimization context
     * @param warningCollector receives {@code EXPIRED_CUBE} warnings when cubes cannot be trusted
     * @return the rewritten plan, or an empty result if no usable cube exists
     */
    public Rule.Result optimize(AggregationNode aggregationNode, Optional<PlanNode> filterNode, TableScanNode tableScanNode,
            Map<String, Object> symbolMappings, Session session, PlanSymbolAllocator symbolAllocator,
            PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        // The metastore is initialized by isEnabled(); without it there is nothing to match against.
        CubeMetaStore metaStore = cubeMetaStore;
        if (metaStore == null) {
            return Rule.Result.empty();
        }
        TableHandle table = tableScanNode.getTable();
        String tableName = metadata.getTableMetadata(session, table).getQualifiedName().toString();
        Optional<FilterNode> filter = filterNode.map(FilterNode.class::cast);
        CubeStatement statement = CubeStatementGenerator.generate(metadata, tableName, aggregationNode, filter.orElse(null), symbolMappings);
        if (statement.getAggregations().isEmpty()) {
            return Rule.Result.empty();
        }

        TypeProvider types = symbolAllocator.getTypes();
        List<CubeMetadata> matchingCubes = CubeMetadata.filter(metaStore.getMetadataList(statement.getFrom()), statement).stream()
                .filter(cube -> filterPredicateMatches(filter, cube, session, types))
                .collect(Collectors.toList());
        if (matchingCubes.isEmpty()) {
            return Rule.Result.empty();
        }

        LongSupplier lastModifiedTimeSupplier = metadata.getTableLastModifiedTimeSupplier(session, table);
        if (lastModifiedTimeSupplier == null) {
            warningCollector.add(new PrestoWarning(StandardWarningCode.EXPIRED_CUBE, "Unable to identify last modified time of " + tableName + ". Ignoring star tree cubes."));
            return Rule.Result.empty();
        }
        long tableLastModifiedTime = lastModifiedTimeSupplier.getAsLong();

        // Among the cubes that are not older than the table, pick the most recently updated one
        // (ties resolve to the earliest in metastore order).
        Optional<CubeMetadata> newestCube = matchingCubes.stream()
                .filter(cube -> cube.getSourceTableLastUpdatedTime() >= tableLastModifiedTime)
                .max(Comparator.comparingLong(CubeMetadata::getLastUpdatedTime));
        if (!newestCube.isPresent()) {
            warningCollector.add(new PrestoWarning(StandardWarningCode.EXPIRED_CUBE, tableName + " has been modified after creating cubes. Ignoring expired cubes."));
            return Rule.Result.empty();
        }

        AggregationRewriteWithCube rewriter = new AggregationRewriteWithCube(metadata, session, symbolAllocator, idAllocator, symbolMappings, newestCube.get());
        return Rule.Result.ofPlanNode(rewriter.rewrite(aggregationNode, filterNode.orElse(null)));
    }

    /**
     * Returns true if at least one of the disjuncts fully translates to a tuple domain and that domain contains
     * {@code tupleDomain}.
     */
    private boolean atLeastMatchesOne(List<Expression> disjuncts, TupleDomain<Symbol> tupleDomain, Session session, TypeProvider types) {
        return disjuncts.stream().anyMatch(disjunct -> {
            ExpressionDomainTranslator.ExtractionResult result = ExpressionDomainTranslator.fromPredicate(metadata, session, disjunct, types);
            return BooleanLiteral.TRUE_LITERAL.equals(result.getRemainingExpression()) && result.getTupleDomain().contains(tupleDomain);
        });
    }

    /**
     * Checks whether the cube's stored predicate covers the query's filter.
     *
     * <ul>
     *   <li>A cube without a predicate matches any query.</li>
     *   <li>A cube with a predicate requires a filter that is an original {@link Expression} and that translates
     *       completely into a tuple domain.</li>
     *   <li>That domain must be contained in the cube predicate's domain. If the cube predicate does not translate
     *       completely, the domain must instead be contained in at least one of its disjuncts.</li>
     * </ul>
     */
    private boolean filterPredicateMatches(Optional<FilterNode> filterNode, CubeMetadata cubeMetadata, Session session, TypeProvider types) {
        String cubePredicateString = cubeMetadata.getPredicateString();
        if (cubePredicateString == null) {
            return true;
        }
        if (!filterNode.isPresent()) {
            return false;
        }
        RowExpression predicate = filterNode.get().getPredicate();
        if (!OriginalExpressionUtils.isExpression(predicate)) {
            LOGGER.error("StarTree index cannot support predicate %s", predicate);
            return false;
        }
        Expression queryPredicate = ExpressionUtils.rewriteIdentifiersToSymbolReferences(OriginalExpressionUtils.castToExpression(predicate));
        ExpressionDomainTranslator.ExtractionResult queryResult = ExpressionDomainTranslator.fromPredicate(metadata, session, queryPredicate, types);
        if (!BooleanLiteral.TRUE_LITERAL.equals(queryResult.getRemainingExpression())) {
            LOGGER.error("StarTree index cannot support predicate %s", predicate);
            return false;
        }

        // Parse the cube predicate only once the query predicate is known to be usable.
        Expression cubePredicate = ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(cubePredicateString, new ParsingOptions()));
        ExpressionDomainTranslator.ExtractionResult cubeResult = ExpressionDomainTranslator.fromPredicate(metadata, session, cubePredicate, types);
        if (BooleanLiteral.TRUE_LITERAL.equals(cubeResult.getRemainingExpression())) {
            return cubeResult.getTupleDomain().contains(queryResult.getTupleDomain());
        }
        return atLeastMatchesOne(ExpressionUtils.extractDisjuncts(cubePredicate), queryResult.getTupleDomain(), session, types);
    }

    /**
     * Maps every symbol name referenced by the aggregation to what it resolves to at the table scan.
     *
     * <p>The mapping covers the aggregation's output symbols, the symbols used by its aggregations, and the
     * symbols of the filter predicate.
     *
     * <p>Each name starts mapped to itself and is then followed through {@code projections} in list order. A
     * projection that assigns a symbol reference (or a cast of one) renames the target, and one that assigns a
     * literal replaces it with that {@link Literal}. Finally, names that are table-scan output symbols are
     * replaced with their {@link ColumnHandle}.
     *
     * @return a mutable map from symbol name to a {@link ColumnHandle}, a {@link Literal}, or an unresolved symbol name
     */
    public static Map<String, Object> buildSymbolMappings(AggregationNode aggregationNode, List<ProjectNode> projections,
            Optional<PlanNode> filterNode, TableScanNode tableScanNode) {
        Map<String, Object> symbolMappings = new HashMap<>();
        putIdentityMappings(symbolMappings, aggregationNode.getOutputSymbols());
        aggregationNode.getAggregations().values()
                .forEach(aggregation -> putIdentityMappings(symbolMappings, SymbolsExtractor.extractUnique(aggregation)));
        filterNode.ifPresent(node -> putIdentityMappings(symbolMappings, SymbolsExtractor.extractUnique(((FilterNode) node).getPredicate())));

        for (ProjectNode projection : projections) {
            Map<Symbol, RowExpression> assignments = projection.getAssignments().getMap();
            for (Map.Entry<String, Object> entry : symbolMappings.entrySet()) {
                // Only still-unresolved symbol names can be followed further; literals are final.
                if (!(entry.getValue() instanceof String)) {
                    continue;
                }
                RowExpression assigned = assignments.get(new Symbol((String) entry.getValue()));
                if (assigned == null) {
                    continue;
                }
                Object resolved = resolveAssignment(assigned);
                if (resolved != null) {
                    entry.setValue(resolved);
                }
            }
        }

        Map<Symbol, ColumnHandle> scanAssignments = tableScanNode.getAssignments();
        for (Map.Entry<String, Object> entry : symbolMappings.entrySet()) {
            if (entry.getValue() instanceof String) {
                ColumnHandle column = scanAssignments.get(new Symbol((String) entry.getValue()));
                if (column != null) {
                    entry.setValue(column);
                }
            }
        }
        return symbolMappings;
    }

    /** Adds a name-to-same-name entry for every symbol. */
    private static void putIdentityMappings(Map<String, Object> mappings, Iterable<Symbol> symbols) {
        for (Symbol symbol : symbols) {
            mappings.put(symbol.getName(), symbol.getName());
        }
    }

    /**
     * Resolves one projection assignment to the symbol name or literal it forwards.
     *
     * @return a symbol name ({@link String}), a {@link Literal}, or null if the assignment is not a simple forward
     */
    private static Object resolveAssignment(RowExpression assigned) {
        if (OriginalExpressionUtils.isExpression(assigned)) {
            Expression expression = OriginalExpressionUtils.castToExpression(assigned);
            if (expression instanceof Cast) {
                expression = ((Cast) expression).getExpression();
            }
            if (expression instanceof SymbolReference) {
                return ((SymbolReference) expression).getName();
            }
            if (expression instanceof Literal) {
                return expression;
            }
            return null;
        }
        RowExpression unwrapped = unwrapSingleArgumentCalls(assigned);
        if (unwrapped instanceof VariableReferenceExpression) {
            return ((VariableReferenceExpression) unwrapped).getName();
        }
        return null;
    }

    /**
     * Strips nested {@link CallExpression}s that take exactly one argument (e.g. a cast).
     *
     * <p>Unwrapping stops at calls with zero or several arguments. Zero-argument calls have no operand to follow.
     * Multi-argument calls such as {@code x + y} must not be mistaken for their first operand.
     */
    private static RowExpression unwrapSingleArgumentCalls(RowExpression expression) {
        RowExpression current = expression;
        while (current instanceof CallExpression && ((CallExpression) current).getArguments().size() == 1) {
            current = ((CallExpression) current).getArguments().get(0);
        }
        return current;
    }

    /** Returns true when the aggregation has outputs and every aggregate is {@link #isSupported supported}. */
    static boolean isSupportedAggregation(AggregationNode aggregationNode) {
        if (aggregationNode.getOutputSymbols().isEmpty()) {
            return false;
        }
        return aggregationNode.getAggregations().values().stream().allMatch(StarTreeAggregationRule::isSupported);
    }

    /**
     * Returns true for a supported aggregate. The function must be one of avg/count/sum/min/max and take at most
     * one argument; only {@code count} may be DISTINCT.
     */
    static boolean isSupported(AggregationNode.Aggregation aggregation) {
        String functionName = aggregation.getFunctionCall().getDisplayName();
        return SUPPORTED_FUNCTIONS.contains(functionName)
                && aggregation.getFunctionCall().getArguments().size() <= 1
                && (!aggregation.isDistinct() || COUNT.equals(functionName));
    }

    /**
     * Returns true if the optional node is absent, or is a {@link ProjectNode} whose assignments are ALL simple
     * forwards.
     *
     * <p>A simple forward is a symbol reference or literal, a cast of one of those, or a variable reference wrapped
     * in single-argument calls.
     */
    static boolean supportedProjectNode(Optional<PlanNode> node) {
        if (!node.isPresent()) {
            return true;
        }
        if (!(node.get() instanceof ProjectNode)) {
            return false;
        }
        // Every assignment must be checked; a supported first assignment says nothing about the rest.
        for (RowExpression assigned : ((ProjectNode) node.get()).getAssignments().getMap().values()) {
            if (!isSupportedAssignment(assigned)) {
                return false;
            }
        }
        return true;
    }

    /** Returns true if a single projection assignment is a simple forward (see {@link #supportedProjectNode}). */
    private static boolean isSupportedAssignment(RowExpression assigned) {
        if (OriginalExpressionUtils.isExpression(assigned)) {
            Expression expression = OriginalExpressionUtils.castToExpression(assigned);
            if (expression instanceof Cast) {
                expression = ((Cast) expression).getExpression();
            }
            return expression instanceof SymbolReference || expression instanceof Literal;
        }
        return unwrapSingleArgumentCalls(assigned) instanceof VariableReferenceExpression;
    }
}
