/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.selection.cv;

import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.cv.CrossValidationResult;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator;
import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.metric.Metric;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;

public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
        return this.score(trainer, scoreCalculator, ignite, upstreamCache, (IgniteBiPredicate & Serializable)(k, v) -> true, featureExtractor, lbExtractor, new SHA256UniformMapper(), cv);
    }

    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
        return this.score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor, new SHA256UniformMapper(), cv);
    }

    public CrossValidationResult score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv, ParamGrid paramGrid) {
        List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();
        CrossValidationResult cvRes = new CrossValidationResult();
        paramSets.forEach(paramSet -> {
            HashMap<String, Double> paramMap = new HashMap<String, Double>();
            for (int paramIdx = 0; paramIdx < ((Double[])paramSet).length; ++paramIdx) {
                String paramName = paramGrid.getParamNameByIndex(paramIdx);
                Double paramVal = paramSet[paramIdx];
                paramMap.put(paramName, paramVal);
                try {
                    String mtdName = "with" + paramName.substring(0, 1).toUpperCase() + paramName.substring(1);
                    Method trainerSetter = null;
                    for (Method method : trainer.getClass().getDeclaredMethods()) {
                        if (!method.getName().equals(mtdName)) continue;
                        trainerSetter = method;
                    }
                    if (trainerSetter == null) {
                        throw new NoSuchMethodException(mtdName);
                    }
                    trainerSetter.invoke((Object)trainer, paramVal);
                    continue;
                }
                catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
                    e.printStackTrace();
                }
            }
            double[] locScores = this.score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor, new SHA256UniformMapper(), cv);
            cvRes.addScores(locScores, paramMap);
            double locAvgScore = Arrays.stream(locScores).average().orElse(Double.MIN_VALUE);
            if (locAvgScore > cvRes.getBestAvgScore()) {
                cvRes.setBestScore(locScores);
                cvRes.setBestHyperParams(paramMap);
                System.out.println(((Object)paramMap).toString());
            }
        });
        return cvRes;
    }

    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) {
        return this.score(trainer, (IgniteBiPredicate<K, V> predicate) -> new CacheBasedDatasetBuilder(ignite, upstreamCache, (IgniteBiPredicate & Serializable)(k, v) -> filter.apply(k, v) && predicate.apply(k, v)), (IgniteBiPredicate<K, V> predicate, M mdl) -> new CacheBasedLabelPairCursor(upstreamCache, (IgniteBiPredicate & Serializable)(k, v) -> filter.apply(k, v) && !predicate.apply(k, v), featureExtractor, lbExtractor, mdl), featureExtractor, lbExtractor, scoreCalculator, mapper, cv);
    }

    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
        return this.score(trainer, scoreCalculator, upstreamMap, (IgniteBiPredicate & Serializable)(k, v) -> true, parts, featureExtractor, lbExtractor, new SHA256UniformMapper(), cv);
    }

    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
        return this.score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor, new SHA256UniformMapper(), cv);
    }

    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap, IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) {
        return this.score(trainer, (IgniteBiPredicate<K, V> predicate) -> new LocalDatasetBuilder(upstreamMap, (IgniteBiPredicate & Serializable)(k, v) -> filter.apply(k, v) && predicate.apply(k, v), parts), (IgniteBiPredicate<K, V> predicate, M mdl) -> new LocalLabelPairCursor(upstreamMap, (IgniteBiPredicate & Serializable)(k, v) -> filter.apply(k, v) && !predicate.apply(k, v), featureExtractor, lbExtractor, mdl), featureExtractor, lbExtractor, scoreCalculator, mapper, cv);
    }

    private double[] score(DatasetTrainer<M, L> trainer, Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier, BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) {
        double[] scores = new double[cv];
        double foldSize = 1.0 / (double)cv;
        for (int i = 0; i < cv; ++i) {
            double from = foldSize * (double)i;
            double to = foldSize * (double)(i + 1);
            IgniteBiPredicate & Serializable trainSetFilter = (IgniteBiPredicate & Serializable)(k, v) -> {
                double pnt = mapper.map(k, v);
                return pnt < from || pnt > to;
            };
            DatasetBuilder<K, V> datasetBuilder = datasetBuilderSupplier.apply(trainSetFilter);
            M mdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
            try (LabelPairCursor<L> cursor = testDataIterSupplier.apply(trainSetFilter, mdl);){
                scores[i] = scoreCalculator.score(cursor.iterator());
                continue;
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        return scores;
    }
}

