/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.bagging;

import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.bagging.BaggedModel;
import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
import org.apache.ignite.ml.util.Utils;

/**
 * Trainer that builds a bagged ensemble ({@link BaggedModel}) out of copies of a single base trainer.
 *
 * <p>Each of the {@code ensembleSize} members is the base trainer wrapped in an
 * {@link AdaptableDatasetTrainer}. Each wrapper gets its own {@link BaggingUpstreamTransformer},
 * parameterized by {@code subsampleRatio} and the member index.
 *
 * <p>Feature subspacing applies when {@code featuresVectorSize > 0} and
 * {@code featureSubspaceDim != featuresVectorSize}. In that case each member sees only the feature
 * indices chosen for it by {@link #getMapping}.
 *
 * <p>All members are trained through a {@link TrainersParallelComposition}. Their predictions are
 * combined by the supplied {@link PredictionsAggregator}.
 *
 * @param <L> Label type, as in {@link DatasetTrainer}.
 */
public class BaggedTrainer<L> extends DatasetTrainer<BaggedModel, L> {
    /** Base trainer; one adapted wrapper around it is created per ensemble member. */
    private final DatasetTrainer<? extends IgniteModel, L> tr;

    /** Combines the members' predictions into a single value. */
    private final PredictionsAggregator aggregator;

    /** Number of ensemble members. */
    private final int ensembleSize;

    /** Passed to {@link BaggingUpstreamTransformer#builder} to control per-member subsampling. */
    private final double subsampleRatio;

    /** Full feature vector size; a non-positive value disables feature subspacing. */
    private final int featuresVectorSize;

    /** Number of feature indices selected per member when feature subspacing is enabled. */
    private final int featureSubspaceDim;

    /**
     * Creates a bagging trainer.
     *
     * @param tr Base trainer used for every ensemble member.
     * @param aggregator Aggregator applied to the members' predictions.
     * @param ensembleSize Number of ensemble members; must be non-negative.
     * @param subsampleRatio Subsampling parameter handed to {@link BaggingUpstreamTransformer#builder}.
     * @param featuresVectorSize Full feature vector size, or a non-positive value to train every
     *     member on all features.
     * @param featureSubspaceDim Number of features per member; only used when
     *     {@code featuresVectorSize > 0} and it differs from {@code featuresVectorSize}.
     * @throws NullPointerException If {@code tr} or {@code aggregator} is {@code null}.
     * @throws IllegalArgumentException If {@code ensembleSize} is negative.
     */
    public BaggedTrainer(DatasetTrainer<? extends IgniteModel, L> tr, PredictionsAggregator aggregator, int ensembleSize, double subsampleRatio, int featuresVectorSize, int featureSubspaceDim) {
        if (ensembleSize < 0) {
            throw new IllegalArgumentException("ensembleSize must be non-negative: " + ensembleSize);
        }
        this.tr = Objects.requireNonNull(tr, "tr");
        this.aggregator = Objects.requireNonNull(aggregator, "aggregator");
        this.ensembleSize = ensembleSize;
        this.subsampleRatio = subsampleRatio;
        this.featuresVectorSize = featuresVectorSize;
        this.featureSubspaceDim = featureSubspaceDim;
    }

    /**
     * Builds the composite trainer that trains every ensemble member and aggregates their outputs.
     *
     * @return Trainer producing one model whose output is the aggregated member prediction.
     */
    private DatasetTrainer<IgniteModel<Vector, Double>, L> getTrainer() {
        // Subspacing is used only when the full vector size is known and the subspace size differs from it.
        // One seed per member is drawn from the environment's RNG, in member order.
        List<int[]> mappings = this.featuresVectorSize > 0 && this.featureSubspaceDim != this.featuresVectorSize
            ? IntStream.range(0, this.ensembleSize)
                .mapToObj(modelIdx -> BaggedTrainer.getMapping(this.featuresVectorSize, this.featureSubspaceDim, this.environment.randomNumbersGenerator().nextLong()))
                .collect(Collectors.toList())
            : null;

        List subspaceTrainers = IntStream.range(0, this.ensembleSize).mapToObj(mdlIdx -> {
            AdaptableDatasetTrainer adapted = AdaptableDatasetTrainer.of((DatasetTrainer)this.tr);
            if (mappings != null) {
                // Resolve this member's mapping once. The closures below then capture only its own
                // index array, not the whole list of mappings, and do no per-call list lookup.
                int[] mapping = mappings.get(mdlIdx);
                adapted = adapted.afterFeatureExtractor(featureValues -> {
                    double[] newFeaturesValues = new double[mapping.length];
                    for (int j = 0; j < mapping.length; ++j) {
                        newFeaturesValues[j] = featureValues.get(mapping[j]);
                    }
                    return VectorUtils.of(newFeaturesValues);
                }).beforeTrainedModel(VectorUtils.getProjector(mapping));
            }
            return adapted.withUpstreamTransformerBuilder(BaggingUpstreamTransformer.builder(this.subsampleRatio, mdlIdx)).withEnvironmentBuilder(this.envBuilder);
        }).map(CompositionUtils::unsafeCoerce).collect(Collectors.toList());

        // Copy the aggregator into a local so the trained model's closure does not capture this trainer.
        PredictionsAggregator agg = this.aggregator;
        AdaptableDatasetTrainer finalTrainer = AdaptableDatasetTrainer.of(new TrainersParallelComposition(subspaceTrainers)).afterTrainedModel(l -> (Double)agg.apply(l.stream().mapToDouble(Double::doubleValue).toArray()));
        return CompositionUtils.unsafeCoerce(finalTrainer);
    }

    /**
     * Chooses the feature indices one ensemble member is trained on.
     *
     * @param featuresVectorSize Size of the full feature vector.
     * @param maximumFeaturesCntPerMdl Number of indices to select.
     * @param seed Seed for the {@link Random} passed to {@link Utils#selectKDistinct}.
     * @return Selected feature indices, as returned by {@link Utils#selectKDistinct}.
     */
    public static int[] getMapping(int featuresVectorSize, int maximumFeaturesCntPerMdl, long seed) {
        return Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed));
    }

    /**
     * Trains a new bagged ensemble.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @return Trained bagged model.
     */
    @Override
    public <K, V> BaggedModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        IgniteModel<Vector, Double> fit = this.getTrainer().fit(datasetBuilder, preprocessor);
        return new BaggedModel(fit);
    }

    /**
     * Updates an existing bagged ensemble by delegating to the composite trainer's {@code update}.
     *
     * <p>This overrides {@code update} directly, so {@link #isUpdateable} and {@link #updateModel}
     * are not used.
     *
     * @param mdl Model to update.
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @return Updated bagged model.
     */
    @Override
    public <K, V> BaggedModel update(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        this.learningEnvironment().initDeployingContext(preprocessor);
        IgniteModel<Vector, Double> updated = this.getTrainer().update(mdl.model(), datasetBuilder, preprocessor);
        return new BaggedModel(updated);
    }

    /**
     * Sets the learning environment builder.
     *
     * @param envBuilder Learning environment builder.
     * @return This trainer, typed as {@code BaggedTrainer}.
     */
    @Override
    public BaggedTrainer<L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (BaggedTrainer<L>)super.withEnvironmentBuilder(envBuilder);
    }

    /**
     * Not supported: {@link #update} is overridden directly and never consults this method.
     *
     * @throws IllegalStateException Always.
     */
    @Override
    public boolean isUpdateable(BaggedModel mdl) {
        throw new IllegalStateException("BaggedTrainer overrides update() directly; isUpdateable() must not be called");
    }

    /**
     * Not supported: {@link #update} is overridden directly and never delegates here.
     *
     * @throws IllegalStateException Always.
     */
    @Override
    protected <K, V> BaggedModel updateModel(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        throw new IllegalStateException("BaggedTrainer overrides update() directly; updateModel() must not be called");
    }
}

