/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.trainers;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.environment.parallelism.Promise;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.functions.IgniteTriFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
import org.apache.ignite.ml.util.Utils;

public class TrainerTransformers {
    /**
     * Wraps the given trainer into a bagged trainer without specifying any feature-vector sizing.
     *
     * <p>Delegates to the six-argument overload, passing {@code -1} for both the feature vector size
     * and the feature subspace dimension. NOTE(review): {@code -1} presumably means "do not sample a
     * feature subspace" - confirm against {@code BaggedTrainer}.
     *
     * @param trainer Trainer to be bagged.
     * @param ensembleSize Number of models in the ensemble.
     * @param subsampleRatio Subsample ratio forwarded to the bagged trainer.
     * @param aggregator Aggregator forwarded to the bagged trainer.
     * @param <L> Type of labels.
     * @return Bagged trainer built by the six-argument overload.
     */
    public static <L> BaggedTrainer<L> makeBagged(DatasetTrainer<? extends IgniteModel, L> trainer, int ensembleSize, double subsampleRatio, PredictionsAggregator aggregator) {
        final int unspecifiedFeatureVectorSize = -1;
        final int unspecifiedFeaturesSubspaceDim = -1;
        return TrainerTransformers.makeBagged(
            trainer,
            ensembleSize,
            subsampleRatio,
            unspecifiedFeatureVectorSize,
            unspecifiedFeaturesSubspaceDim,
            aggregator);
    }

    /**
     * Wraps the given trainer into a {@link BaggedTrainer}.
     *
     * <p>This method performs no work of its own: every argument is handed to the
     * {@code BaggedTrainer} constructor unchanged.
     *
     * @param trainer Trainer to be bagged.
     * @param ensembleSize Number of models in the ensemble.
     * @param subsampleRatio Subsample ratio forwarded to the bagged trainer.
     * @param featureVectorSize Feature vector size forwarded to the bagged trainer.
     * @param featuresSubspaceDim Feature subspace dimension forwarded to the bagged trainer.
     * @param aggregator Aggregator forwarded to the bagged trainer.
     * @param <M> Type of the model produced by {@code trainer}.
     * @param <L> Type of labels.
     * @return Newly created bagged trainer.
     */
    public static <M extends IgniteModel<Vector, Double>, L> BaggedTrainer<L> makeBagged(DatasetTrainer<M, L> trainer, int ensembleSize, double subsampleRatio, int featureVectorSize, int featuresSubspaceDim, PredictionsAggregator aggregator) {
        BaggedTrainer<L> bagged = new BaggedTrainer<L>(
            trainer,
            aggregator,
            ensembleSize,
            subsampleRatio,
            featureVectorSize,
            featuresSubspaceDim);
        return bagged;
    }

    /**
     * Trains {@code ensembleSize} models and combines them into a single {@link ModelsComposition}.
     *
     * <p>Steps performed:
     * <ol>
     *   <li>If {@code featuresVectorSize > 0} and {@code featureSubspaceDim != featuresVectorSize}, one
     *       feature mapping per model is generated via {@link #getMapping(int, int, long)}; otherwise
     *       no mappings are used and every model receives the original extractor.</li>
     *   <li>For each model index a dataset builder with a {@link BaggingUpstreamTransformer} is created
     *       and a training task is obtained from {@code trainingTaskGenerator}.</li>
     *   <li>All tasks are submitted through the environment's parallelism strategy and their results
     *       are wrapped into {@link ModelWithMapping} instances.</li>
     *   <li>When mappings exist, each wrapped model gets the projector for its own mapping, so
     *       prediction inputs are projected before they reach the model.</li>
     * </ol>
     *
     * @param trainingTaskGenerator Produces the training task for a (dataset builder, model index,
     *      feature extractor) triple.
     * @param datasetBuilder Base dataset builder.
     * @param ensembleSize Number of models to train.
     * @param subsampleRatio Subsample ratio passed to the bagging upstream transformer.
     * @param featuresVectorSize Size of the full feature vector; mappings are only built when positive.
     * @param featureSubspaceDim Number of features requested per model.
     * @param extractor Base feature extractor.
     * @param aggregator Aggregator handed to the resulting composition.
     * @param environment Learning environment supplying the logger, the random numbers generator and
     *      the parallelism strategy.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @param <M> Type of the trained models.
     * @return Composition of the trained models.
     */
    private static <K, V, M extends IgniteModel<Vector, Double>> ModelsComposition runOnEnsemble(IgniteTriFunction<DatasetBuilder<K, V>, Integer, IgniteBiFunction<K, V, Vector>, IgniteSupplier<M>> trainingTaskGenerator, DatasetBuilder<K, V> datasetBuilder, int ensembleSize, double subsampleRatio, int featuresVectorSize, int featureSubspaceDim, IgniteBiFunction<K, V, Vector> extractor, PredictionsAggregator aggregator, LearningEnvironment environment) {
        MLLogger log = environment.logger(datasetBuilder.getClass());
        log.log(MLLogger.VerboseLevel.LOW, "Start learning.");

        // A null list means "no feature subspace": every model sees the full feature vector.
        List<int[]> mappings = null;
        if (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) {
            // The model index is added to the random seed so that each model gets its own seed.
            mappings = IntStream.range(0, ensembleSize)
                .mapToObj(modelIdx -> TrainerTransformers.getMapping(featuresVectorSize, featureSubspaceDim, environment.randomNumbersGenerator().nextLong() + (long)modelIdx))
                .collect(Collectors.toList());
        }

        // nanoTime is monotonic, so the measured duration is immune to wall-clock adjustments.
        long startNanos = System.nanoTime();

        List<IgniteSupplier<M>> tasks = new ArrayList<>();
        for (int i = 0; i < ensembleSize; ++i) {
            DatasetBuilder<K, V> newBuilder = datasetBuilder.withUpstreamTransformer(BaggingUpstreamTransformer.builder(subsampleRatio, i));
            IgniteBiFunction<K, V, Vector> mdlExtractor = mappings != null
                ? TrainerTransformers.wrapExtractor(extractor, mappings.get(i))
                : extractor;
            tasks.add(trainingTaskGenerator.apply(newBuilder, i, mdlExtractor));
        }

        List<ModelWithMapping<Vector, Double, M>> models = environment.parallelismStrategy().submit(tasks).stream()
            .map(Promise::unsafeGet)
            .map(ModelWithMapping<Vector, Double, M>::new)
            .collect(Collectors.toList());

        if (mappings != null) {
            // Models were trained on projected features, so prediction inputs need the same projection.
            for (int i = 0; i < models.size(); ++i) {
                models.get(i).setMapping(VectorUtils.getProjector(mappings.get(i)));
            }
        }

        double learningTime = (double)(System.nanoTime() - startNanos) / 1.0e9;
        log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs.", learningTime);
        log.log(MLLogger.VerboseLevel.LOW, "Learning finished.");
        return new ModelsComposition(models, aggregator);
    }

    /**
     * Builds a feature mapping for a single model.
     *
     * <p>Delegates to {@code Utils.selectKDistinct}, handing it a {@link Random} created from
     * {@code seed}. NOTE(review): judging by the helper's name the result holds
     * {@code maximumFeaturesCntPerMdl} distinct indices drawn from {@code [0, featuresVectorSize)} -
     * confirm against {@code Utils}.
     *
     * @param featuresVectorSize Size of the full feature vector.
     * @param maximumFeaturesCntPerMdl Number of features requested for the model.
     * @param seed Seed for the random generator used by the selection.
     * @return Array returned by {@code Utils.selectKDistinct}.
     */
    public static int[] getMapping(int featuresVectorSize, int maximumFeaturesCntPerMdl, long seed) {
        Random seededRnd = new Random(seed);
        return Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, seededRnd);
    }

    /**
     * Decorates a feature extractor so that its output is projected onto the given feature indices.
     *
     * <p>The returned extractor first runs {@code featureExtractor} and then builds a new vector whose
     * {@code i}-th component is the extracted vector's component at index {@code featureMapping[i]}.
     *
     * @param featureExtractor Extractor producing the full feature vector.
     * @param featureMapping Indices of the features to keep, in output order.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Extractor yielding vectors of length {@code featureMapping.length}.
     */
    private static <K, V> IgniteBiFunction<K, V, Vector> wrapExtractor(IgniteBiFunction<K, V, Vector> featureExtractor, int[] featureMapping) {
        return featureExtractor.andThen(fullVector -> {
            double[] projected = IntStream.of(featureMapping)
                .mapToDouble(srcIdx -> fullVector.get(srcIdx))
                .toArray();
            return VectorUtils.of(projected);
        });
    }

    private static class ModelWithMapping<X, Y, M extends IgniteModel<X, Y>>
    implements IgniteModel<X, Y> {
        private final M model;
        private IgniteFunction<X, X> mapping;

        public ModelWithMapping(M model) {
            this(model, x -> x);
        }

        public ModelWithMapping(M model, IgniteFunction<X, X> mapping) {
            this.model = model;
            this.mapping = mapping;
        }

        public void setMapping(IgniteFunction<X, X> mapping) {
            this.mapping = mapping;
        }

        @Override
        public Y predict(X x) {
            return (Y)this.model.predict(this.mapping.apply(x));
        }

        public M model() {
            return this.model;
        }

        public IgniteFunction<X, X> mapping() {
            return this.mapping;
        }
    }
}

