/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.selection.cv;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.parallelism.Promise;
import org.apache.ignite.ml.math.functions.IgniteDoubleConsumer;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.pipeline.Pipeline;
import org.apache.ignite.ml.pipeline.PipelineMdl;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.selection.cv.CrossValidationResult;
import org.apache.ignite.ml.selection.paramgrid.BruteForceStrategy;
import org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy;
import org.apache.ignite.ml.selection.paramgrid.HyperParameterTuningStrategy;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator;
import org.apache.ignite.ml.selection.paramgrid.RandomStrategy;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.Metric;
import org.apache.ignite.ml.selection.scoring.metric.MetricName;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.util.genetic.Chromosome;
import org.apache.ignite.ml.util.genetic.GeneticAlgorithm;
import org.jetbrains.annotations.NotNull;

public abstract class AbstractCrossValidation<M extends IgniteModel<Vector, Double>, K, V>
implements Serializable {
    protected LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder();
    protected LearningEnvironment environment = this.envBuilder.buildForTrainer();
    protected DatasetTrainer<M, Double> trainer;
    protected Pipeline<K, V, Integer, Double> pipeline;
    protected Metric metric;
    protected Preprocessor<K, V> preprocessor;
    protected IgniteBiPredicate<K, V> filter = (IgniteBiPredicate & Serializable)(k, v) -> true;
    protected int amountOfFolds;
    protected int parts;
    protected ParamGrid paramGrid;
    protected boolean isRunningOnPipeline = true;
    protected UniformMapper<K, V> mapper = new SHA256UniformMapper();

    public CrossValidationResult tuneHyperParameters() {
        HyperParameterTuningStrategy hyperParamTuningStgy = this.paramGrid.getHyperParameterTuningStrategy();
        if (hyperParamTuningStgy instanceof BruteForceStrategy) {
            return this.scoreBruteForceHyperparameterOptimization();
        }
        if (hyperParamTuningStgy instanceof RandomStrategy) {
            return this.scoreRandomSearchHyperparameterOptimization();
        }
        if (hyperParamTuningStgy instanceof EvolutionOptimizationStrategy) {
            return this.scoreEvolutionAlgorithmSearchHyperparameterOptimization();
        }
        throw new UnsupportedOperationException("This strategy is not supported yet [strategy=" + this.paramGrid.getHyperParameterTuningStrategy().getName() + "]");
    }

    private CrossValidationResult scoreEvolutionAlgorithmSearchHyperparameterOptimization() {
        EvolutionOptimizationStrategy stgy = (EvolutionOptimizationStrategy)this.paramGrid.getHyperParameterTuningStrategy();
        List<Double[]> paramSets = new ParameterSetGenerator(this.paramGrid.getParamValuesByParamIdx()).generate();
        ArrayList<Double[]> paramSetsCp = new ArrayList<Double[]>(paramSets);
        Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));
        int sizeOfPopulation = 20;
        List<Double[]> rndParamSets = paramSetsCp.subList(0, sizeOfPopulation);
        CrossValidationResult cvRes = new CrossValidationResult();
        Function<Chromosome, Double> fitnessFunction = chromosome -> {
            TaskResult tr = this.calculateScoresForFixedParamSet(chromosome.toDoubleArray());
            cvRes.addScores(tr.locScores, tr.paramMap);
            return Arrays.stream(tr.locScores).average().orElse(Double.MIN_VALUE);
        };
        Random rnd = new Random(stgy.getSeed());
        BiFunction<Integer, Double, Double> mutator = (geneIdx, geneValue) -> {
            Double[] possibleGeneValues = this.paramGrid.getParamRawData().get((int)geneIdx);
            Double newGeneVal = possibleGeneValues[rnd.nextInt(possibleGeneValues.length)];
            return newGeneVal;
        };
        GeneticAlgorithm ga = new GeneticAlgorithm(rndParamSets);
        ga.withFitnessFunction(fitnessFunction).withMutationOperator(mutator).withAmountOfEliteChromosomes(stgy.getNumberOfEliteChromosomes()).withCrossingoverProbability(stgy.getCrossingoverProbability()).withCrossoverStgy(stgy.getCrossoverStgy()).withAmountOfGenerations(stgy.getNumberOfGenerations()).withSelectionStgy(stgy.getSelectionStgy()).withMutationProbability(stgy.getMutationProbability());
        if (this.environment.parallelismStrategy().getParallelism() > 1) {
            ga.runParallel(this.environment);
        } else {
            ga.run();
        }
        return cvRes;
    }

    private CrossValidationResult scoreRandomSearchHyperparameterOptimization() {
        RandomStrategy stgy = (RandomStrategy)this.paramGrid.getHyperParameterTuningStrategy();
        List<Double[]> paramSets = new ParameterSetGenerator(this.paramGrid.getParamValuesByParamIdx()).generate();
        ArrayList<Double[]> paramSetsCp = new ArrayList<Double[]>(paramSets);
        Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));
        CrossValidationResult cvRes = new CrossValidationResult();
        List rndParamSets = paramSetsCp.subList(0, stgy.getMaxTries());
        List tasks = rndParamSets.stream().map(paramSet -> () -> this.calculateScoresForFixedParamSet((Double[])paramSet)).collect(Collectors.toList());
        List<TaskResult> taskResults = this.environment.parallelismStrategy().submit(tasks).stream().map(Promise::unsafeGet).collect(Collectors.toList());
        taskResults.forEach(tr -> cvRes.addScores(((TaskResult)tr).locScores, ((TaskResult)tr).paramMap));
        return cvRes;
    }

    private CrossValidationResult scoreBruteForceHyperparameterOptimization() {
        List<Double[]> paramSets = new ParameterSetGenerator(this.paramGrid.getParamValuesByParamIdx()).generate();
        CrossValidationResult cvRes = new CrossValidationResult();
        List tasks = paramSets.stream().map(paramSet -> () -> this.calculateScoresForFixedParamSet((Double[])paramSet)).collect(Collectors.toList());
        List<TaskResult> taskResults = this.environment.parallelismStrategy().submit(tasks).stream().map(Promise::unsafeGet).collect(Collectors.toList());
        taskResults.forEach(tr -> cvRes.addScores(((TaskResult)tr).locScores, ((TaskResult)tr).paramMap));
        return cvRes;
    }

    private TaskResult calculateScoresForFixedParamSet(Double[] paramSet) {
        Map<String, Double> paramMap = this.injectAndGetParametersFromPipeline(this.paramGrid, paramSet);
        double[] locScores = this.scoreByFolds();
        return new TaskResult(paramMap, locScores);
    }

    public abstract double[] scoreByFolds();

    @NotNull
    private Map<String, Double> injectAndGetParametersFromPipeline(ParamGrid paramGrid, Double[] paramSet) {
        HashMap<String, Double> paramMap = new HashMap<String, Double>();
        for (int paramIdx = 0; paramIdx < paramSet.length; ++paramIdx) {
            IgniteDoubleConsumer setter = paramGrid.getSetterByIndex(paramIdx);
            Double paramVal = paramSet[paramIdx];
            setter.accept(paramVal);
            paramMap.put(paramGrid.getParamNameByIndex(paramIdx), paramVal);
        }
        return paramMap;
    }

    protected double[] score(Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier) {
        double[] scores = new double[this.amountOfFolds];
        double foldSize = 1.0 / (double)this.amountOfFolds;
        for (int i = 0; i < this.amountOfFolds; ++i) {
            double from = foldSize * (double)i;
            double to = foldSize * (double)(i + 1);
            IgniteBiPredicate & Serializable trainSetFilter = (IgniteBiPredicate & Serializable)(k, v) -> {
                double pnt = this.mapper.map(k, v);
                return pnt < from || pnt > to;
            };
            IgniteBiPredicate & Serializable testSetFilter = (IgniteBiPredicate & Serializable)(k, v) -> !trainSetFilter.apply(k, v);
            DatasetBuilder<K, V> trainSet = datasetBuilderSupplier.apply(trainSetFilter);
            M mdl = this.trainer.fit(trainSet, this.preprocessor);
            DatasetBuilder<K, V> testSet = datasetBuilderSupplier.apply(testSetFilter);
            scores[i] = Evaluator.evaluate(testSet, mdl, this.preprocessor, new Metric[]{this.metric}).getSingle();
        }
        return scores;
    }

    protected double[] scorePipeline(Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier) {
        double[] scores = new double[this.amountOfFolds];
        double foldSize = 1.0 / (double)this.amountOfFolds;
        for (int i = 0; i < this.amountOfFolds; ++i) {
            double from = foldSize * (double)i;
            double to = foldSize * (double)(i + 1);
            IgniteBiPredicate & Serializable trainSetFilter = (IgniteBiPredicate & Serializable)(k, v) -> {
                double pnt = this.mapper.map(k, v);
                return pnt < from || pnt > to;
            };
            IgniteBiPredicate & Serializable testSetFilter = (IgniteBiPredicate & Serializable)(k, v) -> !trainSetFilter.apply(k, v);
            DatasetBuilder<K, V> trainSet = datasetBuilderSupplier.apply(trainSetFilter);
            PipelineMdl<K, V> mdl = this.pipeline.fit(trainSet);
            DatasetBuilder<K, V> testSet = datasetBuilderSupplier.apply(testSetFilter);
            scores[i] = Evaluator.evaluate(testSet, mdl, this.pipeline.getFinalPreprocessor(), this.metric).getSingle();
        }
        return scores;
    }

    public AbstractCrossValidation<M, K, V> withTrainer(DatasetTrainer<M, Double> trainer) {
        this.trainer = trainer;
        return this;
    }

    public AbstractCrossValidation<M, K, V> withMetric(MetricName metric) {
        this.metric = metric.create();
        return this;
    }

    public AbstractCrossValidation<M, K, V> withPreprocessor(Preprocessor<K, V> preprocessor) {
        this.preprocessor = preprocessor;
        return this;
    }

    public AbstractCrossValidation<M, K, V> withFilter(IgniteBiPredicate<K, V> filter) {
        this.filter = filter;
        return this;
    }

    public AbstractCrossValidation<M, K, V> withAmountOfFolds(int amountOfFolds) {
        this.amountOfFolds = amountOfFolds;
        return this;
    }

    public AbstractCrossValidation<M, K, V> withParamGrid(ParamGrid paramGrid) {
        this.paramGrid = paramGrid;
        return this;
    }

    public AbstractCrossValidation<M, K, V> isRunningOnPipeline(boolean runningOnPipeline) {
        this.isRunningOnPipeline = runningOnPipeline;
        return this;
    }

    public AbstractCrossValidation<M, K, V> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        this.envBuilder = envBuilder;
        this.environment = envBuilder.buildForTrainer();
        return this;
    }

    public AbstractCrossValidation<M, K, V> withPipeline(Pipeline<K, V, Integer, Double> pipeline) {
        this.pipeline = pipeline;
        return this;
    }

    public AbstractCrossValidation<M, K, V> withMapper(UniformMapper<K, V> mapper) {
        this.mapper = mapper;
        return this;
    }

    public static class TaskResult {
        private Map<String, Double> paramMap;
        private double[] locScores;

        public TaskResult(Map<String, Double> paramMap, double[] locScores) {
            this.paramMap = Collections.unmodifiableMap(paramMap);
            this.locScores = (double[])locScores.clone();
        }

        public void setParamMap(Map<String, Double> paramMap) {
            this.paramMap = Collections.unmodifiableMap(paramMap);
        }

        public void setLocScores(double[] locScores) {
            this.locScores = (double[])locScores.clone();
        }
    }
}

