/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.selection.scoring.evaluator;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BinaryOperator;
import javax.cache.Cache;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.Query;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.knn.KNNModel;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.selection.scoring.evaluator.EvaluationResult;
import org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator;
import org.apache.ignite.ml.selection.scoring.evaluator.context.EvaluationContext;
import org.apache.ignite.ml.selection.scoring.metric.Metric;
import org.apache.ignite.ml.selection.scoring.metric.MetricName;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Computes quality metrics of a model over data held either in an Ignite cache or in a local {@link Map}.
 *
 * <p>Cache-backed overloads build a {@link CacheBasedDatasetBuilder} on the default Ignite instance
 * ({@link Ignition#ignite()}). Map-backed overloads build a single-partition {@link LocalDatasetBuilder}.
 * Overloads without a filter evaluate every entry.
 */
public class Evaluator {
    /** Number of partitions used when evaluating over a local {@link Map}. */
    private static final int LOCAL_PARTITIONS = 1;

    /**
     * Computes one metric over every entry of an Ignite cache.
     *
     * @param dataCache Cache holding the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a cache entry (key, value) into a labeled vector.
     * @param metric Metric to compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Value of {@code metric}.
     */
    public static <K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, Metric metric) {
        return evaluate(dataCache, Evaluator.<K, V>acceptAll(), mdl, preprocessor, metric);
    }

    /**
     * Computes one metric over the entries of an Ignite cache accepted by {@code filter}.
     *
     * @param dataCache Cache holding the evaluation data.
     * @param filter Selects the entries that take part in the evaluation.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a cache entry (key, value) into a labeled vector.
     * @param metric Metric to compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Value of {@code metric}.
     */
    public static <K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, Metric metric) {
        return evaluate(cacheBuilder(dataCache, filter), mdl, preprocessor, metric).getSingle();
    }

    /**
     * Computes the metric identified by {@code metric} over every entry of an Ignite cache.
     *
     * @param dataCache Cache holding the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a cache entry (key, value) into a labeled vector.
     * @param metric Name of the metric to create and compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Value of the metric.
     */
    public static <K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, MetricName metric) {
        return evaluate(dataCache, mdl, preprocessor, metric.create());
    }

    /**
     * Computes the metric identified by {@code metric} over the cache entries accepted by {@code filter}.
     *
     * @param dataCache Cache holding the evaluation data.
     * @param filter Selects the entries that take part in the evaluation.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a cache entry (key, value) into a labeled vector.
     * @param metric Name of the metric to create and compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Value of the metric.
     */
    public static <K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, MetricName metric) {
        return evaluate(dataCache, filter, mdl, preprocessor, metric.create());
    }

    /**
     * Computes one metric over every entry of a local map.
     *
     * @param dataCache Map holding the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a map entry (key, value) into a labeled vector.
     * @param metric Metric to compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Value of {@code metric}.
     */
    public static <K, V> double evaluate(Map<K, V> dataCache, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, Metric metric) {
        return evaluate(dataCache, Evaluator.<K, V>acceptAll(), mdl, preprocessor, metric);
    }

    /**
     * Computes one metric over the entries of a local map accepted by {@code filter}.
     *
     * @param dataCache Map holding the evaluation data.
     * @param filter Selects the entries that take part in the evaluation.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a map entry (key, value) into a labeled vector.
     * @param metric Metric to compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Value of {@code metric}.
     */
    public static <K, V> double evaluate(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, Metric metric) {
        return evaluate(localBuilder(dataCache, filter), mdl, preprocessor, metric).getSingle();
    }

    /**
     * Computes the metric identified by {@code metric} over every entry of a local map.
     *
     * @param dataCache Map holding the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a map entry (key, value) into a labeled vector.
     * @param metric Name of the metric to create and compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Value of the metric.
     */
    public static <K, V> double evaluate(Map<K, V> dataCache, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, MetricName metric) {
        return evaluate(dataCache, Evaluator.<K, V>acceptAll(), mdl, preprocessor, metric.create());
    }

    /**
     * Computes the metric identified by {@code metric} over the map entries accepted by {@code filter}.
     *
     * @param dataCache Map holding the evaluation data.
     * @param filter Selects the entries that take part in the evaluation.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a map entry (key, value) into a labeled vector.
     * @param metric Name of the metric to create and compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Value of the metric.
     */
    public static <K, V> double evaluate(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, MetricName metric) {
        return evaluate(dataCache, filter, mdl, preprocessor, metric.create());
    }

    /**
     * Computes the standard binary-classification metric set over every entry of an Ignite cache.
     *
     * @param dataCache Cache holding the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a cache entry (key, value) into a labeled vector.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Values of all metrics listed in {@link #binaryClassificationMetrics()}.
     */
    public static <K, V> EvaluationResult evaluateBinaryClassification(IgniteCache<K, V> dataCache, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor) {
        return evaluateBinaryClassification(dataCache, Evaluator.<K, V>acceptAll(), mdl, preprocessor);
    }

    /**
     * Computes the standard binary-classification metric set over the cache entries accepted by {@code filter}.
     *
     * @param dataCache Cache holding the evaluation data.
     * @param filter Selects the entries that take part in the evaluation.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a cache entry (key, value) into a labeled vector.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Values of all metrics listed in {@link #binaryClassificationMetrics()}.
     */
    public static <K, V> EvaluationResult evaluateBinaryClassification(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor) {
        return evaluate(cacheBuilder(dataCache, filter), mdl, preprocessor, binaryClassificationMetrics());
    }

    /**
     * Computes the standard binary-classification metric set over every entry of a local map.
     *
     * @param dataCache Map holding the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a map entry (key, value) into a labeled vector.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Values of all metrics listed in {@link #binaryClassificationMetrics()}.
     */
    public static <K, V> EvaluationResult evaluateBinaryClassification(Map<K, V> dataCache, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor) {
        return evaluateBinaryClassification(dataCache, Evaluator.<K, V>acceptAll(), mdl, preprocessor);
    }

    /**
     * Computes the standard binary-classification metric set over the map entries accepted by {@code filter}.
     * Uses the same metric list as the cache-backed overloads.
     *
     * @param dataCache Map holding the evaluation data.
     * @param filter Selects the entries that take part in the evaluation.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a map entry (key, value) into a labeled vector.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Values of all metrics listed in {@link #binaryClassificationMetrics()}.
     */
    public static <K, V> EvaluationResult evaluateBinaryClassification(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor) {
        return evaluate(localBuilder(dataCache, filter), mdl, preprocessor, binaryClassificationMetrics());
    }

    /**
     * Computes the standard regression metric set over every entry of an Ignite cache.
     *
     * @param dataCache Cache holding the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a cache entry (key, value) into a labeled vector.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Values of MAE, MSE, R2, RMSE and RSS.
     */
    public static <K, V> EvaluationResult evaluateRegression(IgniteCache<K, V> dataCache, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor) {
        return evaluateRegression(dataCache, Evaluator.<K, V>acceptAll(), mdl, preprocessor);
    }

    /**
     * Computes the standard regression metric set over the cache entries accepted by {@code filter}.
     *
     * @param dataCache Cache holding the evaluation data.
     * @param filter Selects the entries that take part in the evaluation.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a cache entry (key, value) into a labeled vector.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Values of MAE, MSE, R2, RMSE and RSS.
     */
    public static <K, V> EvaluationResult evaluateRegression(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor) {
        return evaluate(cacheBuilder(dataCache, filter), mdl, preprocessor, regressionMetrics());
    }

    /**
     * Computes the standard regression metric set over every entry of a local map.
     *
     * @param dataCache Map holding the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a map entry (key, value) into a labeled vector.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Values of MAE, MSE, R2, RMSE and RSS.
     */
    public static <K, V> EvaluationResult evaluateRegression(Map<K, V> dataCache, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor) {
        return evaluateRegression(dataCache, Evaluator.<K, V>acceptAll(), mdl, preprocessor);
    }

    /**
     * Computes the standard regression metric set over the map entries accepted by {@code filter}.
     *
     * @param dataCache Map holding the evaluation data.
     * @param filter Selects the entries that take part in the evaluation.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns a map entry (key, value) into a labeled vector.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Values of MAE, MSE, R2, RMSE and RSS.
     */
    public static <K, V> EvaluationResult evaluateRegression(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor) {
        return evaluate(localBuilder(dataCache, filter), mdl, preprocessor, regressionMetrics());
    }

    /**
     * Builds a dataset from {@code datasetBuilder} and computes every metric in {@code metrics} on it.
     * The dataset is closed before returning.
     *
     * @param datasetBuilder Source of the evaluation data.
     * @param mdl Model to evaluate.
     * @param preprocessor Turns an upstream entry (key, value) into a labeled vector.
     * @param metrics Metrics to compute.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Metric values keyed by metric name.
     * @throws RuntimeException Wrapping any exception raised while building, evaluating or closing the dataset.
     */
    public static <K, V> EvaluationResult evaluate(DatasetBuilder<K, V> datasetBuilder, IgniteModel<Vector, Double> mdl, Preprocessor<K, V> preprocessor, Metric... metrics) {
        try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
            LearningEnvironmentBuilder.defaultBuilder(),
            new EmptyContextBuilder<>(),
            new FeatureMatrixWithLabelsOnHeapDataBuilder<>(preprocessor),
            LearningEnvironment.DEFAULT_TRAINER_ENV)) {
            // The upstream cache is only known for cache-backed builders; it enables the local-scan path in computeStats.
            IgniteCache<K, V> cache = datasetBuilder instanceof CacheBasedDatasetBuilder
                ? ((CacheBasedDatasetBuilder<K, V>)datasetBuilder).getUpstreamCache()
                : null;

            return computeMetrics(mdl, dataset, cache, preprocessor, metrics);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns a serializable filter that accepts every (key, value) pair. */
    private static <K, V> IgniteBiPredicate<K, V> acceptAll() {
        return (IgniteBiPredicate<K, V> & Serializable)(k, v) -> true;
    }

    /** Creates a dataset builder over an Ignite cache using the default Ignite instance. */
    private static <K, V> DatasetBuilder<K, V> cacheBuilder(IgniteCache<K, V> cache, IgniteBiPredicate<K, V> filter) {
        return new CacheBasedDatasetBuilder<>(Ignition.ignite(), cache, filter);
    }

    /** Creates a single-partition dataset builder over a local map. */
    private static <K, V> DatasetBuilder<K, V> localBuilder(Map<K, V> map, IgniteBiPredicate<K, V> filter) {
        return new LocalDatasetBuilder<>(map, filter, LOCAL_PARTITIONS);
    }

    /** Metrics computed by every {@code evaluateBinaryClassification} overload. */
    private static Metric[] binaryClassificationMetrics() {
        return createMetrics(
            MetricName.ACCURACY, MetricName.PRECISION, MetricName.RECALL, MetricName.F_MEASURE,
            MetricName.BALANCED_ACCURACY, MetricName.FALL_OUT, MetricName.FDR, MetricName.MISS_RATE,
            MetricName.NPV, MetricName.SPECIFICITY, MetricName.TRUE_POSITIVE, MetricName.FALSE_POSITIVE,
            MetricName.TRUE_NEGATIVE, MetricName.FALSE_NEGATIVE);
    }

    /** Metrics computed by every {@code evaluateRegression} overload. */
    private static Metric[] regressionMetrics() {
        return createMetrics(MetricName.MAE, MetricName.MSE, MetricName.R2, MetricName.RMSE, MetricName.RSS);
    }

    /** Instantiates one metric per name, preserving order. */
    private static Metric[] createMetrics(MetricName... names) {
        return Arrays.stream(names).map(MetricName::create).toArray(Metric[]::new);
    }

    /**
     * Computes {@code metrics} on an already built dataset.
     * Contexts are initialized first, then statistics are aggregated, then each metric is finalized from its aggregator.
     */
    private static <K, V> EvaluationResult computeMetrics(IgniteModel<Vector, Double> mdl, Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, IgniteCache<K, V> cache, Preprocessor<K, V> preprocessor, Metric[] metrics) {
        Map<Class<?>, EvaluationContext> ctxs = initEvaluationContexts(dataset, metrics);
        Map<Class<?>, MetricStatsAggregator> aggrs = computeStats(mdl, dataset, cache, preprocessor, ctxs, metrics);

        Map<MetricName, Double> res = new HashMap<>();
        for (Metric metric : metrics) {
            // Aggregators are shared per aggregator class, so look the metric's one up by class.
            MetricStatsAggregator aggr = aggrs.get(metric.makeAggregator().getClass());
            // If two metrics share a name, the later one wins, as before.
            res.put(metric.name(), metric.initBy(aggr).value());
        }
        return new EvaluationResult(res);
    }

    /**
     * Prepares one evaluation context per distinct aggregator class.
     * A pass over the dataset is made only if at least one context reports {@code needToCompute()}.
     */
    private static Map<Class<?>, EvaluationContext> initEvaluationContexts(Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, Metric... metrics) {
        boolean noPassNeeded = Arrays.stream(metrics)
            .map(m -> (EvaluationContext)m.makeAggregator().createInitializedContext())
            .noneMatch(EvaluationContext::needToCompute);

        // Register a context for EVERY aggregator class, not just the first metric's.
        if (noPassNeeded) {
            return newContexts(metrics);
        }

        return dataset.compute(
            (FeatureMatrixWithLabelsOnHeapData data) -> {
                Map<Class<?>, EvaluationContext> ctxs = newContexts(metrics);
                for (int i = 0; i < data.getLabels().length; i++) {
                    LabeledVector<Double> row = labeledRow(data, i);
                    for (EvaluationContext ctx : ctxs.values()) {
                        ctx.aggregate(row);
                    }
                }
                return ctxs;
            },
            (Map<Class<?>, EvaluationContext> left, Map<Class<?>, EvaluationContext> right) ->
                mergeByClass(left, right, (a, b) -> a.mergeWith(b), "ctx1 != null && ctx2 != null"));
    }

    /**
     * Aggregates per-aggregator-class statistics of {@code mdl} on the evaluation data.
     * For models that must be evaluated locally (see {@link #isOnlyLocalEstimation}), the upstream cache is
     * scanned from the calling node, when it is known. Otherwise partitions are processed via {@code dataset.compute}.
     */
    private static <K, V> Map<Class<?>, MetricStatsAggregator> computeStats(IgniteModel<Vector, Double> mdl, Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, IgniteCache<K, V> cache, Preprocessor<K, V> preprocessor, Map<Class<?>, EvaluationContext> ctxs, Metric... metrics) {
        if (isOnlyLocalEstimation(mdl) && cache != null) {
            Map<Class<?>, MetricStatsAggregator> aggrs = initAggregators(ctxs, metrics);
            // NOTE(review): this ScanQuery has no filter, so entries rejected by the dataset builder's filter
            // are evaluated here too — confirm whether that is intended.
            try (QueryCursor<Cache.Entry<K, V>> cursor = cache.query(new ScanQuery<K, V>())) {
                for (Cache.Entry<K, V> entry : cursor) {
                    LabeledVector vector = (LabeledVector)preprocessor.apply(entry.getKey(), entry.getValue());
                    for (MetricStatsAggregator aggr : aggrs.values()) {
                        aggr.aggregate(mdl, vector);
                    }
                }
            }
            return aggrs;
        }

        return dataset.compute(
            (FeatureMatrixWithLabelsOnHeapData data) -> {
                Map<Class<?>, MetricStatsAggregator> aggrs = initAggregators(ctxs, metrics);
                for (int i = 0; i < data.getLabels().length; i++) {
                    LabeledVector<Double> row = labeledRow(data, i);
                    for (MetricStatsAggregator aggr : aggrs.values()) {
                        aggr.aggregate(mdl, row);
                    }
                }
                return aggrs;
            },
            (Map<Class<?>, MetricStatsAggregator> left, Map<Class<?>, MetricStatsAggregator> right) ->
                mergeByClass(left, right, (a, b) -> a.mergeWith(b), "agg1 != null && agg2 != null"));
    }

    /**
     * Creates one aggregator per distinct aggregator class and initializes it from the matching context.
     *
     * @throws IllegalArgumentException (via {@code A.ensure}) if a metric's aggregator class has no context.
     */
    private static Map<Class<?>, MetricStatsAggregator> initAggregators(Map<Class<?>, EvaluationContext> ctxs, Metric[] metrics) {
        Map<Class<?>, MetricStatsAggregator> aggrs = new HashMap<>();
        for (Metric m : metrics) {
            MetricStatsAggregator aggregator = m.makeAggregator();
            Class<?> cls = aggregator.getClass();
            EvaluationContext ctx = ctxs.get(cls);
            A.ensure(ctx != null, "ctx != null");

            // Only the first aggregator of each class is kept; skip initializing ones that would be discarded.
            if (aggrs.containsKey(cls)) {
                continue;
            }
            aggregator.initByContext(ctx);
            aggrs.put(cls, aggregator);
        }
        return aggrs;
    }

    /** Creates a freshly initialized context for each distinct aggregator class; the first metric of a class supplies it. */
    private static Map<Class<?>, EvaluationContext> newContexts(Metric[] metrics) {
        Map<Class<?>, EvaluationContext> ctxs = new HashMap<>();
        for (Metric m : metrics) {
            MetricStatsAggregator aggregator = m.makeAggregator();
            if (!ctxs.containsKey(aggregator.getClass())) {
                ctxs.put(aggregator.getClass(), (EvaluationContext)aggregator.createInitializedContext());
            }
        }
        return ctxs;
    }

    /** Builds the labeled vector for row {@code i} of a partition's feature matrix. */
    private static LabeledVector<Double> labeledRow(FeatureMatrixWithLabelsOnHeapData data, int i) {
        return VectorUtils.of(data.getFeatures()[i]).labeled(data.getLabels()[i]);
    }

    /**
     * Reduces two per-class maps coming from different partitions.
     * A {@code null} side yields the other side; if both are {@code null}, an empty map is returned.
     * Every key of {@code left} must also be present in {@code right}.
     */
    private static <T> Map<Class<?>, T> mergeByClass(Map<Class<?>, T> left, Map<Class<?>, T> right, BinaryOperator<T> merger, String missingMsg) {
        if (left == null) {
            return right == null ? new HashMap<>() : right;
        }
        if (right == null) {
            return left;
        }

        Map<Class<?>, T> res = new HashMap<>();
        for (Map.Entry<Class<?>, T> e : left.entrySet()) {
            T l = e.getValue();
            T r = right.get(e.getKey());
            A.ensure(l != null && r != null, missingMsg);
            res.put(e.getKey(), merger.apply(l, r));
        }
        return res;
    }

    /** Returns {@code true} for models whose statistics are collected by scanning the upstream cache locally ({@link KNNModel}). */
    private static boolean isOnlyLocalEstimation(IgniteModel<Vector, Double> mdl) {
        return mdl instanceof KNNModel;
    }
}

