package org.apache.hadoop.hive.ql.intercept.rules;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import jodd.util.StringPool;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.Driver;
import org.apache.hadoop.hive.ql.hooks.Entity;
import org.apache.hadoop.hive.ql.hooks.ReadEntity;
import org.apache.hadoop.hive.ql.hooks.WriteEntity;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.parse.ASTNode;
import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer;
import org.apache.hadoop.hive.ql.parse.QB;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hadoop.yarn.client.RequestHedgingRMFailoverProxyProvider;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/hive/ql/intercept/rules/RuleUtils.class */
/**
 * Static helpers used by the SQL intercept ("sql defense") rules.
 *
 * <p>Rules are obtained from {@link RuleWatcher#getRules()} as a map keyed by tenant, plus one
 * entry under the key {@code "A"} that this class applies to every user regardless of tenant.
 * A user's tenant is the default queue name reported by YARN for that user and is cached.
 *
 * <p>NOTE(review): {@link #LOG} is created for {@code Rule.class}, not {@code RuleUtils.class};
 * this is kept as-is so existing log categories do not change.
 */
public class RuleUtils {
    /** Key under which rules applying to all tenants are stored in the rule map. */
    private static final String RULE_ALL_TENANT_KEY = "A";
    public static final String SQL_IS_EXPLAIN = "sql.is.explain";
    public static final String SQL_IS_SELECT = "sql.is.select";
    public static final Logger LOG = LoggerFactory.getLogger(Rule.class);
    /** Rules evaluated against the AST right after parsing. */
    private static final RuleId[] RULE_AFTER_PARSE = {RuleId.STATIC_0001, RuleId.STATIC_0002, RuleId.STATIC_0003, RuleId.STATIC_0004};
    /** Rules whose raw parameter string is copied into the query configuration. */
    private static final RuleId[] MINIMAL_RULE = {RuleId.DYNAMIC_0001, RuleId.DYNAMIC_0002, RuleId.RUNNING_0002, RuleId.RUNNING_0003, RuleId.RUNNING_0004};

    // AST node types this class treats as EXPLAIN statements (they set SQL_IS_EXPLAIN and skip
    // rule evaluation). Presumably HiveParser.TOK_EXPLAIN* of this build - TODO confirm.
    private static final int AST_TYPE_EXPLAIN = 836;
    private static final int AST_TYPE_EXPLAIN_ALT = 837;
    // AST node type that sets SQL_IS_SELECT. Presumably HiveParser.TOK_QUERY - TODO confirm.
    private static final int AST_TYPE_SELECT = 961;

    /** Last size that triggered the {@code DYNAMIC_0001} hint, or -1 if none was recorded. */
    public static long ruleFilesHintNum = -1;
    public static final Configuration CONF = new HiveConf();

    /**
     * User name to tenant cache. A loader result of {@code null} is not cached: Guava rejects
     * null values with an exception, which {@link #getTenant(String)} turns back into null.
     */
    private static final LoadingCache<String, String> TENANT_CACHE = CacheBuilder.newBuilder()
            .expireAfterAccess(10, TimeUnit.MINUTES)
            .refreshAfterWrite(10, TimeUnit.MINUTES)
            .maximumSize(2000)
            .build(new CacheLoader<String, String>() {
                @Override
                public String load(String user) {
                    return loadTenantFromYarn(user);
                }
            });

    /**
     * Asks YARN for the default queue name of {@code user}.
     *
     * <p>The client is opened in try-with-resources so it is closed even when init, start or
     * the lookup throws; previously it was closed only on the success path.
     *
     * @param user user name to look up
     * @return the default queue name, or {@code null} if the lookup failed
     */
    private static String loadTenantFromYarn(String user) {
        try (YarnClient yarnClient = YarnClient.createYarnClient()) {
            CONF.set("yarn.timeline-service.enabled", "false");
            CONF.set("yarn.client.failover-proxy-provider", RequestHedgingRMFailoverProxyProvider.class.getName());
            yarnClient.init(CONF);
            yarnClient.start();
            String tenant = yarnClient.getUserInfo(user).getDefaultQueueName();
            LOG.info("current user {}, load tenant from yarn : {}", user, tenant);
            return tenant;
        } catch (Exception e) {
            LOG.error("{} get tenant from yarn failed", user, e);
            return null;
        }
    }

    /**
     * Checks that a rule configures at least one action limit.
     *
     * @param rule rule to validate
     * @return {@code true} if {@code hintLimit} or {@code interceptLimit} is non-negative
     */
    public static boolean valid(Rule rule) {
        if (rule.hintLimit >= 0 || rule.interceptLimit >= 0) {
            return true;
        }
        LOG.error("{} action does not config", rule.ruleId);
        return false;
    }

    /**
     * Validates a rule that must not carry action limits and may have at most one action.
     *
     * @param rule rule to validate
     * @return {@code true} if neither limit is positive and there is at most one action
     */
    public static boolean validIfExist(Rule rule) {
        if (rule.hintLimit > 0 || rule.interceptLimit > 0) {
            LOG.error("{} does not need action param", rule.ruleId);
            return false;
        }
        if (rule.actions.size() > 1) {
            LOG.error("{} support only one action", rule.ruleId);
            return false;
        }
        return true;
    }

    /**
     * Returns the loaded rule map when interception is switched on and rules exist.
     *
     * @param configuration configuration holding {@code HIVE_EXT_SQL_INTERCEPT_ENABLE}
     * @return the non-empty rule map, or {@code null} when disabled or no rules are loaded
     */
    private static Map<String, Map<RuleId, Rule>> enabledRules(Configuration configuration) {
        if (!HiveConf.getBoolVar(configuration, HiveConf.ConfVars.HIVE_EXT_SQL_INTERCEPT_ENABLE)) {
            return null;
        }
        Map<String, Map<RuleId, Rule>> rules = RuleWatcher.getRules();
        return (rules == null || rules.isEmpty()) ? null : rules;
    }

    /**
     * Null-safe lookup of the rules registered under one tenant key.
     *
     * @param tenantKey tenant name or {@link #RULE_ALL_TENANT_KEY}
     * @return the rules for that key, or {@code null} if nothing is loaded for it
     */
    private static Map<RuleId, Rule> rulesFor(String tenantKey) {
        Map<String, Map<RuleId, Rule>> rules = RuleWatcher.getRules();
        // getRules() may return null (getRule/isRuleEnabled already treat that as "no rules").
        return rules == null ? null : rules.get(tenantKey);
    }

    /**
     * Tells whether SQL interception is enabled and at least one rule set is loaded.
     *
     * @param configuration configuration holding the enable flag
     * @return {@code true} if interception is on and rules are available
     */
    public static boolean isRuleEnabled(Configuration configuration) {
        return enabledRules(configuration) != null;
    }

    /**
     * Resolves the rule that applies to a user: the tenant-specific rule if one exists,
     * otherwise the all-tenant rule.
     *
     * @param configuration configuration holding the enable flag
     * @param ruleId        rule to resolve
     * @param user          user name, used for tenant lookup and exemption check
     * @return the applicable rule, or {@code null} if interception is disabled, no rule is
     *         defined, or the user is exempt from the resolved rule (the exemption is recorded)
     */
    public static Rule getRule(Configuration configuration, RuleId ruleId, String user) {
        Map<String, Map<RuleId, Rule>> rules = enabledRules(configuration);
        if (rules == null) {
            return null;
        }
        Rule rule = null;
        String tenant = getTenant(user);
        if (tenant != null) {
            Map<RuleId, Rule> tenantRules = rules.get(tenant);
            if (tenantRules != null) {
                rule = tenantRules.get(ruleId);
            }
        }
        // Only fall back to the all-tenant rule when the tenant defines none; an exempt tenant
        // rule does NOT fall through to the all-tenant rule.
        if (rule == null) {
            Map<RuleId, Rule> allTenantRules = rules.get(RULE_ALL_TENANT_KEY);
            if (allTenantRules != null) {
                rule = allTenantRules.get(ruleId);
            }
        }
        if (rule != null && rule.getExemptUsers().contains(user)) {
            recordRuleExemptMsg(user, ruleId);
            return null;
        }
        return rule;
    }

    /**
     * Reads the intercept limit from a rule parameter of the form {@code "<hint>,<intercept>"}
     * stored in the configuration under the rule id's name.
     *
     * @param configuration configuration holding the rule parameter
     * @param ruleId        rule whose parameter is read
     * @return {@code 0} if no parameter is set; the intercept limit if it is positive;
     *         {@code -1} if the parameter is malformed or the limit is not positive
     */
    public static int getInterceptLimit(Configuration configuration, RuleId ruleId) {
        String param = configuration.get(ruleId.name());
        if (param == null || param.isEmpty()) {
            return 0;
        }
        int comma = param.indexOf(",");
        if (comma < 0) {
            LOG.error("parse rule failed, id : {}, param : {}", ruleId, param);
            return -1;
        }
        try {
            int interceptLimit = Integer.parseInt(param.substring(comma + 1));
            return interceptLimit > 0 ? interceptLimit : -1;
        } catch (NumberFormatException e) {
            LOG.error("parse rule failed, id : {}, param : {}", ruleId, param);
            return -1;
        }
    }

    /**
     * Evaluates the given rules of one tenant key against {@code payload}.
     *
     * <p>A {@link RuleException} propagates to the caller; any other exception raised by a rule
     * is logged and evaluation continues with the next rule. Exempt users are skipped and the
     * exemption is recorded.
     *
     * @param user      user name, checked against each rule's exempt list
     * @param tenantKey tenant name or {@link #RULE_ALL_TENANT_KEY}
     * @param payload   object handed to {@code Rule.evaluate}
     * @param ruleIds   rules to evaluate, in order
     */
    private static void doFilter(String user, String tenantKey, Object payload, RuleId... ruleIds) {
        Map<RuleId, Rule> tenantRules = rulesFor(tenantKey);
        if (tenantRules == null || tenantRules.isEmpty()) {
            return;
        }
        for (RuleId ruleId : ruleIds) {
            Rule rule = tenantRules.get(ruleId);
            if (rule == null) {
                continue;
            }
            if (rule.getExemptUsers().contains(user)) {
                recordRuleExemptMsg(user, ruleId);
                continue;
            }
            try {
                rule.evaluate(payload);
            } catch (RuleException e) {
                // A rule violation must reach the caller.
                throw e;
            } catch (Exception e) {
                // A broken rule must not fail the query.
                LOG.error("do Filter failed", e);
            }
        }
    }

    /**
     * Evaluates rules first for the user's tenant (when one can be resolved) and then for the
     * all-tenant key.
     *
     * @param user    user name
     * @param payload object handed to {@code Rule.evaluate}
     * @param ruleIds rules to evaluate
     */
    private static void doFilterForUser(String user, Object payload, RuleId... ruleIds) {
        String tenant = getTenant(user);
        if (tenant != null) {
            doFilter(user, tenant, payload, ruleIds);
        }
        doFilter(user, RULE_ALL_TENANT_KEY, payload, ruleIds);
    }

    /**
     * Returns the cached tenant of a user, loading it from YARN on a cache miss.
     *
     * @param user user name
     * @return the tenant, or {@code null} if it could not be determined
     */
    public static String getTenant(String user) {
        try {
            return TENANT_CACHE.get(user);
        } catch (Exception e) {
            // The loader already logged any YARN failure; a null load result also ends up here.
            LOG.warn("get tenant for {} failed", user);
            return null;
        }
    }

    /**
     * Evaluates {@code STATIC_0005} against a query block.
     *
     * @param qb   query block handed to the rule
     * @param user user name
     */
    public static void doFilterSubQuery(QB qb, String user) {
        doFilterForUser(user, qb, RuleId.STATIC_0005);
    }

    /**
     * Evaluates {@code STATIC_0006} against a string.
     *
     * @param sql  value handed to the rule (presumably the SQL text - confirm against callers)
     * @param user user name
     */
    public static void doFilterSqlSize(String sql, String user) {
        doFilterForUser(user, sql, RuleId.STATIC_0006);
    }

    /**
     * Evaluates the after-parse rules against the AST, unless the statement is an EXPLAIN.
     *
     * @param configuration configuration that may carry {@link #SQL_IS_EXPLAIN}
     * @param aSTNode       root AST node of the statement
     * @param user          user name
     */
    public static void doFilterAfterParseAst(Configuration configuration, ASTNode aSTNode, String user) {
        if (configuration.getBoolean(SQL_IS_EXPLAIN, false) || isExplainType(aSTNode.getType())) {
            return;
        }
        doFilterForUser(user, aSTNode, RULE_AFTER_PARSE);
    }

    /**
     * Evaluates {@code STATIC_0007} with a boolean flag, unless the statement is an EXPLAIN.
     *
     * @param configuration configuration that may carry {@link #SQL_IS_EXPLAIN}
     * @param user          user name
     * @param z             flag handed to the rule (presumably "is cartesian product" - confirm)
     */
    public static void doFilterCartesian(Configuration configuration, String user, boolean z) {
        if (configuration.getBoolean(SQL_IS_EXPLAIN, false)) {
            return;
        }
        doFilterForUser(user, Boolean.valueOf(z), RuleId.STATIC_0007);
    }

    /**
     * Records the statement kind in the configuration and, for non-EXPLAIN statements, copies
     * the parameters of the minimal rules into it (tenant rules first, then all-tenant rules,
     * so an all-tenant parameter overwrites a tenant parameter for the same rule id).
     *
     * @param configuration configuration to update
     * @param aSTNode       root AST node of the statement
     * @param user          user name
     */
    public static void setRuleToConf(Configuration configuration, ASTNode aSTNode, String user) {
        setSqlTypeToConf(configuration, aSTNode);
        if (isExplainType(aSTNode.getType())) {
            return;
        }
        String tenant = getTenant(user);
        if (tenant != null) {
            setMinimalRuleToConf(configuration, tenant, user);
        }
        setMinimalRuleToConf(configuration, RULE_ALL_TENANT_KEY, user);
    }

    /**
     * Copies the parameter of every minimal rule defined for {@code tenantKey} into the
     * configuration under the rule id's name, skipping rules the user is exempt from.
     *
     * @param configuration configuration to update
     * @param tenantKey     tenant name or {@link #RULE_ALL_TENANT_KEY}
     * @param user          user name
     */
    private static void setMinimalRuleToConf(Configuration configuration, String tenantKey, String user) {
        Map<RuleId, Rule> tenantRules = rulesFor(tenantKey);
        if (tenantRules == null || tenantRules.isEmpty()) {
            return;
        }
        for (RuleId ruleId : MINIMAL_RULE) {
            Rule rule = tenantRules.get(ruleId);
            if (rule == null) {
                continue;
            }
            if (rule.getExemptUsers().contains(user)) {
                recordRuleExemptMsg(user, rule.ruleId);
            } else {
                configuration.set(ruleId.name(), rule.getParam());
            }
        }
    }

    /** Returns whether the AST node type is one this class treats as an EXPLAIN statement. */
    private static boolean isExplainType(int type) {
        return type == AST_TYPE_EXPLAIN || type == AST_TYPE_EXPLAIN_ALT;
    }

    /**
     * Sets {@link #SQL_IS_EXPLAIN} or {@link #SQL_IS_SELECT} in the configuration according to
     * the AST node type; other types leave the configuration untouched.
     */
    private static void setSqlTypeToConf(Configuration configuration, ASTNode aSTNode) {
        int type = aSTNode.getType();
        if (isExplainType(type)) {
            configuration.setBoolean(SQL_IS_EXPLAIN, true);
        }
        if (type == AST_TYPE_SELECT) {
            configuration.setBoolean(SQL_IS_SELECT, true);
        }
    }

    /**
     * Tells whether a parameter for {@code RUNNING_0002} or {@code RUNNING_0004} is present in
     * the configuration.
     */
    public static boolean isTaskRuleEnabled(Configuration configuration) {
        return StringUtils.isNotEmpty(configuration.get(RuleId.RUNNING_0002.name()))
                || StringUtils.isNotEmpty(configuration.get(RuleId.RUNNING_0004.name()));
    }

    /**
     * Compares a measured value against a rule's limits.
     *
     * <p>The intercept limit is checked first: reaching it records the message and throws.
     * Otherwise, reaching the hint limit records a hint and logs a warning through
     * {@code Driver.logWarning} (unless the helper reports the hint should not be logged).
     * Limits that are not positive are disabled.
     *
     * @param ruleId         rule being checked
     * @param hintLimit      threshold for a warning
     * @param interceptLimit threshold for rejecting
     * @param current        measured value
     * @throws RuleException if {@code interceptLimit > 0} and {@code current >= interceptLimit}
     */
    public static void doFilter(RuleId ruleId, int hintLimit, int interceptLimit, long current) {
        if (interceptLimit > 0 && current >= interceptLimit) {
            String message = ruleId.getMessage() + " : " + interceptLimit + ", current size : " + current;
            recordRuleMsgException(ruleId, message);
            throw new RuleException(message);
        }
        if (hintLimit <= 0 || current < hintLimit) {
            return;
        }
        String hint = ruleId.getMessage() + " : " + hintLimit + ", current size : " + current;
        if (ruleId == RuleId.DYNAMIC_0001) {
            ruleFilesHintNum = current;
        }
        try {
            if (recordRuleMsgHint(ruleId, hint)) {
                Driver.logWarning(hint);
            }
        } catch (Throwable th) {
            // Emitting the hint is best-effort; fall back to the plain logger.
            LOG.warn(hint);
        }
    }

    /**
     * Checks a measured value against the rule parameter stored in the configuration; does
     * nothing when no parameter is set for the rule.
     *
     * @param configuration configuration holding the rule parameter
     * @param ruleId        rule being checked
     * @param current       measured value
     */
    public static void doFilterMinimal(Configuration configuration, RuleId ruleId, long current) {
        String param = configuration.get(ruleId.name());
        if (param == null || param.isEmpty()) {
            return;
        }
        doFilterMinimal(param, ruleId, current);
    }

    /**
     * Parses a rule parameter of the form {@code "<hint>,<intercept>"} and checks the measured
     * value against both limits. A malformed parameter is logged and ignored.
     *
     * @param param   raw rule parameter; must not be {@code null}
     * @param ruleId  rule being checked
     * @param current measured value
     * @throws RuleException if the intercept limit is reached
     */
    public static void doFilterMinimal(String param, RuleId ruleId, long current) {
        int comma = param.indexOf(",");
        if (comma < 0) {
            LOG.error("parse rule failed, id : {}, param : {}", ruleId, param);
            return;
        }
        try {
            int hintLimit = Integer.parseInt(param.substring(0, comma));
            int interceptLimit = Integer.parseInt(param.substring(comma + 1));
            doFilter(ruleId, hintLimit, interceptLimit, current);
        } catch (NumberFormatException e) {
            LOG.error("parse rule failed, id : {}, param : {}", ruleId, param);
        }
    }

    /**
     * Fetches an application's {@code allocatedMB} via {@code RestUtils.getAppValue}, folds the
     * change since the previous reading into a running total and checks that total against
     * {@code RUNNING_0002}.
     *
     * @param ruleParam    raw {@code RUNNING_0002} parameter; {@code null} disables the check
     * @param appId        first argument to {@code RestUtils.getAppValue}, also used as the log
     *                     label (presumably a YARN application id - confirm)
     * @param appValueArg  passed through unchanged to {@code RestUtils.getAppValue}; its meaning
     *                     is not visible here
     * @param totalMb      running total, increased by {@code (reading - previousMb)}
     * @param previousMb   the reading previously returned for this application
     * @return the new reading, or {@code 0} if the check is disabled or the reading is not positive
     * @throws RuleException if the running total reaches the intercept limit
     */
    public static long doFilterAllocatedMemory(String ruleParam, String appId, int appValueArg, AtomicLong totalMb, long previousMb) {
        if (ruleParam == null) {
            return 0L;
        }
        long allocatedMb = RestUtils.getAppValue(appId, "allocatedMB", appValueArg);
        if (allocatedMb <= 0) {
            return 0L;
        }
        long total = totalMb.addAndGet(allocatedMb - previousMb);
        LOG.info("{} allocatedMB : {}, total allocatedMB : {}", appId, allocatedMb, total);
        doFilterMinimal(ruleParam, RuleId.RUNNING_0002, total);
        return allocatedMb;
    }

    /**
     * Collects the source and sink table names of a statement and stores them on the helper as
     * JSON-style arrays of quoted names (e.g. {@code ["db.t1","db.t2"]}, or {@code []}).
     *
     * <p>Source tables are collected only for AST types 789, 790, 792 and 961; sink tables for
     * those types and for 895. Any other type yields two empty arrays.
     *
     * @param aSTNode              root AST node of the statement
     * @param baseSemanticAnalyzer analyzer providing the write entities
     * @param sqlRuleHelper        provides the alias-to-table map and receives the result
     */
    public static void analyzeInputOutput(ASTNode aSTNode, BaseSemanticAnalyzer baseSemanticAnalyzer, SqlRuleHelper sqlRuleHelper) {
        StringBuilder sources = new StringBuilder("[");
        StringBuilder sinks = new StringBuilder("[");
        int type = aSTNode.getType();
        Map<String, ReadEntity> aliasToTable = sqlRuleHelper.getAliasToTable();
        HashSet<WriteEntity> outputs = baseSemanticAnalyzer.getOutputs();
        switch (type) {
            case 789:
            case 790:
            case 792:
            case AST_TYPE_SELECT:
                aliasToTable.forEach((alias, readEntity) -> {
                    // Only direct reads of real tables: skip dummy tables and entities that
                    // have parents.
                    if (readEntity.getType() == Entity.Type.TABLE
                            && !readEntity.getTable().isDummyTable()
                            && readEntity.getParents().isEmpty()) {
                        sources.append("\"").append(alias.replace('@', '.')).append("\",");
                    }
                });
                // Intentional fall-through: these statement types also report their sinks.
            case 895:
                if (!outputs.isEmpty()) {
                    HashSet<String> seen = new HashSet<>();
                    outputs.forEach(writeEntity -> {
                        if (writeEntity.getType() != Entity.Type.TABLE && writeEntity.getType() != Entity.Type.PARTITION) {
                            return;
                        }
                        Table table = writeEntity.getTable();
                        if (table == null) {
                            return;
                        }
                        String qualifiedName = table.getDbName() + StringPool.DOT + table.getTableName();
                        // Several partitions of one table must be reported once.
                        if (seen.add(qualifiedName)) {
                            sinks.append("\"").append(qualifiedName).append("\",");
                        }
                    });
                }
                break;
            default:
                break;
        }
        sqlRuleHelper.setSourceTable(closeJsonArray(sources));
        sqlRuleHelper.setSinkTable(closeJsonArray(sinks));
    }

    /**
     * Finishes an array built as {@code [} followed by zero or more {@code "name",} items:
     * drops the trailing comma, if any, and appends {@code ]}.
     */
    private static String closeJsonArray(StringBuilder items) {
        if (items.length() > 1) {
            return items.substring(0, items.length() - 1) + "]";
        }
        return items + "]";
    }

    /**
     * Builds the hint message for a rule from its configured hint limit, records it on the
     * session's rule helper and logs it as a warning.
     *
     * <p>Best-effort: if the parameter is missing or malformed nothing is logged as a warning;
     * if only the recording fails, the warning is still emitted.
     *
     * @param configuration configuration holding the rule parameter
     * @param ruleId        rule that produced the hint
     * @param current       measured value
     */
    public static void recordRuleMsgHint(Configuration configuration, RuleId ruleId, long current) {
        boolean shouldLog = true;
        String hint = null;
        try {
            String param = configuration.get(ruleId.name());
            int hintLimit = Integer.parseInt(param.substring(0, param.indexOf(",")));
            hint = ruleId.getMessage() + " : " + hintLimit + ", current size : " + current;
            shouldLog = SessionState.get().getSqlRuleHelper().addSqlDefenseHint(ruleId, hint);
        } catch (Exception e) {
            LOG.warn("failed to record sql defense info " + e.getMessage());
        }
        if (shouldLog && hint != null) {
            Driver.logWarning(hint);
        }
    }

    /**
     * Records a hint message on the current session's rule helper.
     *
     * @param ruleId  rule that produced the hint
     * @param message hint text
     * @return the helper's result, or {@code true} if recording failed (so the caller still
     *         logs the hint)
     */
    public static boolean recordRuleMsgHint(RuleId ruleId, String message) {
        try {
            return SessionState.get().getSqlRuleHelper().addSqlDefenseHint(ruleId, message);
        } catch (Exception e) {
            LOG.warn("failed to record sql defense info " + e.getMessage());
            return true;
        }
    }

    /**
     * Records an intercept message on the current session's rule helper; failures are logged
     * and otherwise ignored.
     *
     * @param ruleId  rule that intercepted the statement
     * @param message intercept text
     */
    public static void recordRuleMsgException(RuleId ruleId, String message) {
        try {
            SessionState.get().getSqlRuleHelper().addSqlDefenseException(ruleId, message);
        } catch (Exception e) {
            LOG.warn("failed to record sql defense info " + e.getMessage());
        }
    }

    /**
     * Records that a rule was skipped because the user is exempt, and logs a warning unless the
     * helper reports it should not be logged. If recording fails the warning is still logged.
     *
     * @param user   exempt user name
     * @param ruleId rule that was skipped
     */
    public static void recordRuleExemptMsg(String user, RuleId ruleId) {
        String message = "Skip sql defense due to user(" + user + ") is exempted";
        boolean shouldLog = true;
        try {
            SqlRuleHelper sqlRuleHelper = SessionState.get().getSqlRuleHelper();
            if (sqlRuleHelper != null) {
                shouldLog = sqlRuleHelper.addSqlDefenseExemptInfo(ruleId, message);
            }
        } catch (Exception e) {
            LOG.warn("failed to record sql defense info " + e.getMessage());
        }
        if (shouldLog) {
            Driver.logWarning(message + " from rule(" + ruleId + ").");
        }
    }
}
