package io.prestosql.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.UnmodifiableIterator;
import io.airlift.log.Logger;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.spi.ConnectorPlanOptimizer;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.connector.CatalogName;
import io.prestosql.spi.plan.AggregationNode;
import io.prestosql.spi.plan.ExceptNode;
import io.prestosql.spi.plan.FilterNode;
import io.prestosql.spi.plan.GroupIdNode;
import io.prestosql.spi.plan.IntersectNode;
import io.prestosql.spi.plan.JoinNode;
import io.prestosql.spi.plan.LimitNode;
import io.prestosql.spi.plan.MarkDistinctNode;
import io.prestosql.spi.plan.PlanNode;
import io.prestosql.spi.plan.PlanNodeIdAllocator;
import io.prestosql.spi.plan.ProjectNode;
import io.prestosql.spi.plan.Symbol;
import io.prestosql.spi.plan.TableScanNode;
import io.prestosql.spi.plan.TopNNode;
import io.prestosql.spi.plan.UnionNode;
import io.prestosql.spi.plan.WindowNode;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.PlanSymbolAllocator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.plan.DeleteNode;
import io.prestosql.sql.planner.plan.TableWriterNode;

import java.util.ArrayDeque;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.function.Supplier;

/* loaded from: input_file:io/prestosql/sql/planner/optimizations/ApplyConnectorOptimization.class */
/**
 * Hands connector-owned parts of a plan to the {@link ConnectorPlanOptimizer}s registered for the
 * owning catalog.
 *
 * <p>For every catalog reached by the plan's leaves that has registered optimizers, each plan node is
 * annotated with a {@link ConnectorPlanNodeContext} describing the catalogs and node types reachable
 * from it. A subtree is passed to that catalog's optimizers when it is a <em>max closure</em>:
 * <ul>
 *   <li>the subtree reaches only that catalog;</li>
 *   <li>it contains only node types listed in {@link #CONNECTOR_ACCESSIBLE_PLAN_NODES};</li>
 *   <li>it has a parent whose subtree does not satisfy the same condition.</li>
 * </ul>
 * Rewritten subtrees are spliced back into the plan by rebuilding all of their ancestors.
 */
public class ApplyConnectorOptimization implements PlanOptimizer {
    /** The only plan node types a subtree may contain to be handed to a connector optimizer. */
    static final Set<Class<? extends PlanNode>> CONNECTOR_ACCESSIBLE_PLAN_NODES = ImmutableSet.<Class<? extends PlanNode>>of(
            AggregationNode.class,
            TableScanNode.class,
            LimitNode.class,
            ExceptNode.class,
            FilterNode.class,
            IntersectNode.class,
            MarkDistinctNode.class,
            JoinNode.class,
            WindowNode.class,
            ProjectNode.class,
            TopNNode.class,
            UnionNode.class,
            GroupIdNode.class);

    private static final Logger log = Logger.get(ApplyConnectorOptimization.class);

    /**
     * Synthetic catalog attributed to leaves that are not table scans.
     *
     * <p>A subtree containing such a leaf therefore reaches more than one catalog, or a catalog other
     * than the one being optimized, and never forms a closure for a real catalog.
     */
    private static final CatalogName EMPTY_CATALOG_NAME = new CatalogName("$internal$" + ApplyConnectorOptimization.class + "_CATALOG");

    private final Supplier<Map<CatalogName, Set<ConnectorPlanOptimizer>>> connectorOptimizersSupplier;

    /**
     * Per-node summary of what a subtree reaches, built by {@code buildConnectorPlanContext}.
     */
    public static final class ConnectorPlanNodeContext {
        // null for the root of the plan
        private final PlanNode parent;
        private final Set<CatalogName> reachableConnectors;
        private final Set<Class<? extends PlanNode>> reachablePlanNodeTypes;

        /**
         * @param parent parent of the node this context describes, or {@code null} for the plan root
         * @param reachableConnectors catalogs reached by the subtree; must be non-null and non-empty
         * @param reachablePlanNodeTypes node classes present in the subtree; must be non-null and non-empty
         * @throws NullPointerException if either set is null
         * @throws IllegalArgumentException if either set is empty
         */
        ConnectorPlanNodeContext(PlanNode parent, Set<CatalogName> reachableConnectors, Set<Class<? extends PlanNode>> reachablePlanNodeTypes) {
            this.parent = parent;
            this.reachableConnectors = Objects.requireNonNull(reachableConnectors, "reachableConnectors is null");
            this.reachablePlanNodeTypes = Objects.requireNonNull(reachablePlanNodeTypes, "reachablePlanNodeTypes is null");
            Preconditions.checkArgument(!reachableConnectors.isEmpty(), "encountered a PlanNode that reaches no connector");
            Preconditions.checkArgument(!reachablePlanNodeTypes.isEmpty(), "encountered a PlanNode that reaches no plan node");
        }

        /** Returns the parent node, or empty for the plan root. */
        Optional<PlanNode> getParent() {
            return Optional.ofNullable(this.parent);
        }

        public Set<CatalogName> getReachableConnectors() {
            return this.reachableConnectors;
        }

        public Set<Class<? extends PlanNode>> getReachablePlanNodeTypes() {
            return this.reachablePlanNodeTypes;
        }

        /**
         * Returns whether the subtree reaches exactly {@code catalogName} and contains only
         * connector-accessible node types.
         */
        boolean isClosure(CatalogName catalogName) {
            return this.reachableConnectors.size() == 1
                    && this.reachableConnectors.contains(catalogName)
                    && ApplyConnectorOptimization.containsAll(ApplyConnectorOptimization.CONNECTOR_ACCESSIBLE_PLAN_NODES, this.reachablePlanNodeTypes);
        }
    }

    /**
     * @param connectorOptimizersSupplier supplies, per catalog, the optimizers to apply; queried once
     *        per {@link #optimize} call
     * @throws NullPointerException if the supplier is null
     */
    public ApplyConnectorOptimization(Supplier<Map<CatalogName, Set<ConnectorPlanOptimizer>>> connectorOptimizersSupplier) {
        this.connectorOptimizersSupplier = Objects.requireNonNull(connectorOptimizersSupplier, "connectorOptimizerSupplier is null");
    }

    /**
     * Applies each catalog's connector optimizers to the plan's max closures for that catalog.
     *
     * <p>If a {@link PrestoException} is raised while processing a catalog, the plan as rewritten for
     * the catalogs processed so far is returned and the remaining catalogs are skipped. This happens,
     * for example, when the plan contains a {@code DeleteNode} or {@code TableWriterNode} that one of
     * the optimizers does not support.
     *
     * @throws IllegalStateException if an optimizer returns a node whose output symbols do not cover
     *         the original node's output symbols
     */
    @Override // io.prestosql.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, PlanSymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");

        Map<CatalogName, Set<ConnectorPlanOptimizer>> connectorOptimizers = this.connectorOptimizersSupplier.get();
        if (connectorOptimizers.isEmpty()) {
            return plan;
        }

        ImmutableSet.Builder<CatalogName> catalogNames = ImmutableSet.builder();
        getAllCatalogNames(plan, catalogNames);

        PlanNode result = plan;
        for (CatalogName catalogName : catalogNames.build()) {
            Set<ConnectorPlanOptimizer> optimizers = connectorOptimizers.get(catalogName);
            if (optimizers == null) {
                // no optimizers registered for this catalog (always the case for EMPTY_CATALOG_NAME)
                continue;
            }
            try {
                // only commit the new root once this catalog has been processed without error
                result = optimizeForCatalog(result, catalogName, optimizers, session, types, symbolAllocator, idAllocator);
            } catch (PrestoException e) {
                // NOTE(review): this also skips every remaining catalog; confirm that is intended
                log.debug("Not able to apply optimization: %s", e.getMessage());
                return result;
            }
        }
        return result;
    }

    /**
     * Runs {@code optimizers} over every max closure of {@code plan} for {@code catalogName}.
     *
     * @return the new plan root, or {@code plan} itself if nothing changed
     */
    private static PlanNode optimizeForCatalog(PlanNode plan, CatalogName catalogName, Set<ConnectorPlanOptimizer> optimizers, Session session, TypeProvider types, PlanSymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        ImmutableMap.Builder<PlanNode, ConnectorPlanNodeContext> contextBuilder = ImmutableMap.builder();
        buildConnectorPlanContext(plan, null, contextBuilder, optimizers);
        Map<PlanNode, ConnectorPlanNodeContext> contextMap = contextBuilder.build();

        // Rebuilt for every catalog instead of once per call: the TypeProvider may be a live view of the
        // symbol allocator, so symbols allocated by an earlier catalog's optimizers could show up here.
        // NOTE(review): confirm against how callers construct the TypeProvider.
        Map<String, Type> typesByName = new HashMap<>();
        for (Map.Entry<Symbol, Type> entry : types.allTypes().entrySet()) {
            typesByName.put(entry.getKey().getName(), entry.getValue());
        }

        // original node -> rewritten node
        Map<PlanNode, PlanNode> updates = new HashMap<>();
        for (Map.Entry<PlanNode, ConnectorPlanNodeContext> entry : contextMap.entrySet()) {
            if (!isMaxClosure(entry.getValue(), contextMap, catalogName)) {
                continue;
            }
            PlanNode original = entry.getKey();
            PlanNode rewritten = original;
            for (ConnectorPlanOptimizer optimizer : optimizers) {
                rewritten = optimizer.optimize(rewritten, session.toConnectorSession(catalogName), typesByName, symbolAllocator, idAllocator);
            }
            if (rewritten != original) {
                Preconditions.checkState(
                        containsAll(ImmutableSet.copyOf(rewritten.getOutputSymbols()), original.getOutputSymbols()),
                        "the connector optimizer from %s returns a node that does not cover all output before optimization",
                        catalogName);
                updates.put(original, rewritten);
            }
        }
        return propagateUpdates(plan, contextMap, updates);
    }

    /**
     * Returns whether the node described by {@code context} is a closure for {@code catalogName}
     * whose parent is not.
     *
     * <p>The plan root (which has no parent) is never considered a max closure.
     */
    private static boolean isMaxClosure(ConnectorPlanNodeContext context, Map<PlanNode, ConnectorPlanNodeContext> contextMap, CatalogName catalogName) {
        if (!context.isClosure(catalogName)) {
            return false;
        }
        Optional<PlanNode> parent = context.getParent();
        return parent.isPresent() && !contextMap.get(parent.get()).isClosure(catalogName);
    }

    /**
     * Rebuilds every ancestor of the rewritten nodes so that it points at the rewritten children.
     *
     * <p>Each time a node is (re)built its parent is enqueued again, so the last rebuild of every
     * ancestor, including the root, sees all final child versions even though the traversal is not
     * strictly bottom-up. {@code updates} is extended in place with the rebuilt ancestors.
     *
     * @return the rebuilt root, or {@code plan} if no rewrite reached the root
     */
    private static PlanNode propagateUpdates(PlanNode plan, Map<PlanNode, ConnectorPlanNodeContext> contextMap, Map<PlanNode, PlanNode> updates) {
        PlanNode root = plan;
        Queue<PlanNode> pending = new ArrayDeque<>(updates.keySet());
        while (!pending.isEmpty()) {
            PlanNode original = pending.poll();
            Optional<PlanNode> parent = contextMap.get(original).getParent();
            if (!parent.isPresent()) {
                // original is the plan root
                root = updates.get(original);
                continue;
            }
            PlanNode originalParent = parent.get();
            ImmutableList.Builder<PlanNode> newChildren = ImmutableList.builder();
            for (PlanNode child : originalParent.getSources()) {
                newChildren.add(updates.getOrDefault(child, child));
            }
            updates.put(originalParent, originalParent.replaceChildren(newChildren.build()));
            pending.add(originalParent);
        }
        return root;
    }

    /**
     * Adds to {@code catalogNames} the catalog of every leaf of {@code node}.
     *
     * <p>Table scans contribute their table's catalog; any other leaf contributes
     * {@link #EMPTY_CATALOG_NAME}.
     */
    private static void getAllCatalogNames(PlanNode node, ImmutableSet.Builder<CatalogName> catalogNames) {
        if (!node.getSources().isEmpty()) {
            for (PlanNode source : node.getSources()) {
                getAllCatalogNames(source, catalogNames);
            }
        } else if (node instanceof TableScanNode) {
            catalogNames.add(((TableScanNode) node).getTable().getCatalogName());
        } else {
            catalogNames.add(EMPTY_CATALOG_NAME);
        }
    }

    /**
     * Recursively records a {@link ConnectorPlanNodeContext} for {@code node} and all its descendants.
     *
     * @param node subtree root to describe
     * @param parent parent of {@code node}, or {@code null} for the plan root
     * @param contextBuilder receives one entry per visited node
     * @param optimizers optimizers that will run; used to reject unsupported write nodes
     * @return the context recorded for {@code node}
     * @throws PrestoException with {@code NOT_SUPPORTED} if a {@code DeleteNode} or
     *         {@code TableWriterNode} child is found and some optimizer does not support it
     */
    private static ConnectorPlanNodeContext buildConnectorPlanContext(PlanNode node, PlanNode parent, ImmutableMap.Builder<PlanNode, ConnectorPlanNodeContext> contextBuilder, Set<ConnectorPlanOptimizer> optimizers) {
        Set<CatalogName> reachableConnectors;
        Set<Class<? extends PlanNode>> reachablePlanNodeTypes;
        if (!node.getSources().isEmpty()) {
            reachableConnectors = new HashSet<>();
            reachablePlanNodeTypes = new HashSet<>();
            for (PlanNode source : node.getSources()) {
                checkSupportedByOptimizers(source, optimizers);
                ConnectorPlanNodeContext sourceContext = buildConnectorPlanContext(source, node, contextBuilder, optimizers);
                reachableConnectors.addAll(sourceContext.getReachableConnectors());
                reachablePlanNodeTypes.addAll(sourceContext.getReachablePlanNodeTypes());
            }
            reachablePlanNodeTypes.add(node.getClass());
        } else if (node instanceof TableScanNode) {
            reachableConnectors = ImmutableSet.of(((TableScanNode) node).getTable().getCatalogName());
            reachablePlanNodeTypes = ImmutableSet.<Class<? extends PlanNode>>of(TableScanNode.class);
        } else {
            reachableConnectors = ImmutableSet.of(EMPTY_CATALOG_NAME);
            reachablePlanNodeTypes = ImmutableSet.<Class<? extends PlanNode>>of(node.getClass());
        }
        ConnectorPlanNodeContext context = new ConnectorPlanNodeContext(parent, reachableConnectors, reachablePlanNodeTypes);
        contextBuilder.put(node, context);
        return context;
    }

    /**
     * Fails if {@code node} is a delete or table-writer node that any of {@code optimizers} cannot handle.
     *
     * @throws PrestoException with {@code NOT_SUPPORTED} when an optimizer lacks support
     */
    private static void checkSupportedByOptimizers(PlanNode node, Set<ConnectorPlanOptimizer> optimizers) {
        if (node instanceof DeleteNode) {
            for (ConnectorPlanOptimizer optimizer : optimizers) {
                if (!optimizer.isSupportingDeleteNode()) {
                    throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "DeleteNode is not supported");
                }
            }
        }
        if (node instanceof TableWriterNode) {
            for (ConnectorPlanOptimizer optimizer : optimizers) {
                if (!optimizer.isSupportingTableWriterNode()) {
                    throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "TableWriterNode is not supported");
                }
            }
        }
    }

    /** Returns whether every element of {@code collection} is contained in {@code set}. */
    public static <T> boolean containsAll(Set<T> set, Collection<T> collection) {
        for (T element : collection) {
            if (!set.contains(element)) {
                return false;
            }
        }
        return true;
    }
}
