/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.composition.ModelOnFeaturesSubspace;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.environment.parallelism.Promise;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.util.Utils;
import org.jetbrains.annotations.NotNull;

public abstract class BaggingModelTrainer
extends DatasetTrainer<ModelsComposition, Double> {
    private final PredictionsAggregator predictionsAggregator;
    private final int maximumFeaturesCntPerMdl;
    private final int ensembleSize;
    private final double samplePartSizePerMdl;
    private final int featureVectorSize;

    public BaggingModelTrainer(PredictionsAggregator predictionsAggregator, int featureVectorSize, int maximumFeaturesCntPerMdl, int ensembleSize, double samplePartSizePerMdl) {
        this.predictionsAggregator = predictionsAggregator;
        this.maximumFeaturesCntPerMdl = maximumFeaturesCntPerMdl;
        this.ensembleSize = ensembleSize;
        this.samplePartSizePerMdl = samplePartSizePerMdl;
        this.featureVectorSize = featureVectorSize;
    }

    @Override
    public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        MLLogger log = this.environment.logger(this.getClass());
        log.log(MLLogger.VerboseLevel.LOW, "Start learning", new Object[0]);
        Long startTs = System.currentTimeMillis();
        ArrayList tasks = new ArrayList();
        for (int i = 0; i < this.ensembleSize; ++i) {
            tasks.add(() -> this.learnModel(datasetBuilder, featureExtractor, lbExtractor));
        }
        List models = this.environment.parallelismStrategy().submit(tasks).stream().map(Promise::unsafeGet).collect(Collectors.toList());
        double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
        log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);
        log.log(MLLogger.VerboseLevel.LOW, "Learning finished", new Object[0]);
        return new ModelsComposition(models, this.predictionsAggregator);
    }

    @NotNull
    private <K, V> ModelOnFeaturesSubspace learnModel(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        Random rnd = new Random();
        SHA256UniformMapper sampleFilter = new SHA256UniformMapper(rnd);
        long featureExtractorSeed = rnd.nextLong();
        Map<Integer, Integer> featuresMapping = this.createFeaturesMapping(featureExtractorSeed, this.featureVectorSize);
        Long startTs = System.currentTimeMillis();
        Model<Vector, Double> mdl = this.buildDatasetTrainerForModel().fit(datasetBuilder.withFilter((IgniteBiPredicate & Serializable)(features, answer) -> sampleFilter.map(features, answer) < this.samplePartSizePerMdl), this.wrapFeatureExtractor(featureExtractor, featuresMapping), lbExtractor);
        double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
        this.environment.logger(this.getClass()).log(MLLogger.VerboseLevel.HIGH, "One model training time was %.2fs", learningTime);
        return new ModelOnFeaturesSubspace(featuresMapping, mdl);
    }

    private Map<Integer, Integer> createFeaturesMapping(long seed, int featuresVectorSize) {
        int[] featureIdxs = Utils.selectKDistinct(featuresVectorSize, this.maximumFeaturesCntPerMdl, new Random(seed));
        HashMap<Integer, Integer> locFeaturesMapping = new HashMap<Integer, Integer>();
        IntStream.range(0, this.maximumFeaturesCntPerMdl).forEach(localId -> locFeaturesMapping.put(localId, featureIdxs[localId]));
        return locFeaturesMapping;
    }

    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildDatasetTrainerForModel();

    private <K, V> IgniteBiFunction<K, V, Vector> wrapFeatureExtractor(IgniteBiFunction<K, V, Vector> featureExtractor, Map<Integer, Integer> featureMapping) {
        return featureExtractor.andThen(featureValues -> {
            double[] newFeaturesValues = new double[featureMapping.size()];
            featureMapping.forEach((localId, featureValueId) -> {
                newFeaturesValues[localId.intValue()] = featureValues.get((int)featureValueId);
            });
            return VectorUtils.of(newFeaturesValues);
        });
    }

    @Override
    public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        ArrayList<Model<Vector, Double>> newModels = new ArrayList<Model<Vector, Double>>(mdl.getModels());
        newModels.addAll(((ModelsComposition)this.fit((DatasetBuilder)datasetBuilder, (IgniteBiFunction)featureExtractor, (IgniteBiFunction)lbExtractor)).getModels());
        return new ModelsComposition(newModels, this.predictionsAggregator);
    }
}

