/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.boosting;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.jetbrains.annotations.NotNull;

/**
 * Abstract trainer producing gradient-boosted model compositions ({@link GDBModel}).
 * <p>
 * Subclasses decide how labels are prepared and mapped ({@link #learnLabels}, {@link #externalLabelToInternal},
 * {@link #internalLabelToExternal}) and which trainer builds each base model ({@link #buildBaseModelTrainer}).
 * The boosting loop itself is delegated to the {@link GDBLearningStrategy} returned by
 * {@link #getLearningStrategy()}.
 */
public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Double> {
    /** Default gradient step size handed to the learning strategy. */
    private final double gradientStep;

    /** Number of boosting iterations handed to the learning strategy. */
    private final int cntOfIterations;

    /** Loss handed to the learning strategy via {@code withLossGradient}. */
    protected final Loss loss;

    /**
     * Factory of convergence checkers handed to the learning strategy.
     * Defaults to {@code new MeanAbsValueConvergenceCheckerFactory(0.001)}.
     */
    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001);

    /**
     * Creates a trainer.
     *
     * @param gradStepSize Default gradient step size.
     * @param cntOfIterations Number of boosting iterations; must not be {@code null} (it is unboxed here).
     * @param loss Loss used by the learning strategy.
     */
    public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
        this.gradientStep = gradStepSize;
        this.cntOfIterations = cntOfIterations;
        this.loss = loss;
    }

    /**
     * Trains a new composition from scratch; equivalent to updating a {@code null} model.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @return Trained composition.
     */
    @Override
    public <K, V> ModelsComposition fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        return updateModel(null, datasetBuilder, preprocessor);
    }

    /**
     * Trains a new composition or continues training of an existing one.
     * <p>
     * If {@link #learnLabels} returns {@code false}, or the initial value cannot be computed (see
     * {@link #computeInitialValue}), the result of {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)}
     * is returned instead.
     *
     * @param mdl Model to continue training, or {@code null} to train from scratch. A non-null model is cast to
     *     {@link GDBModel}; see {@link #isUpdateable(ModelsComposition)}.
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @return A new {@link GDBModel} wrapping the trained base models.
     */
    @Override
    protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        if (!learnLabels(datasetBuilder, preprocessor)) {
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
        }

        IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(envBuilder, datasetBuilder, preprocessor);
        if (initAndSampleSize == null) {
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
        }

        Double mean = initAndSampleSize.get1();
        Long sampleSize = initAndSampleSize.get2();

        long learningStartTs = System.currentTimeMillis();

        GDBLearningStrategy stgy = getLearningStrategy()
            .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
            .withExternalLabelToInternal(this::externalLabelToInternal)
            .withCntOfIterations(cntOfIterations)
            .withEnvironmentBuilder(envBuilder)
            .withLossGradient(loss)
            .withSampleSize(sampleSize)
            .withMeanLabelValue(mean)
            .withDefaultGradStepSize(gradientStep)
            .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory);

        List<IgniteModel<Vector, Double>> models = mdl != null
            ? stgy.update((GDBModel)mdl, datasetBuilder, preprocessor)
            : stgy.learnModels(datasetBuilder, preprocessor);

        double learningTime = (System.currentTimeMillis() - learningStartTs) / 1000.0;
        environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);

        WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(
            stgy.getCompositionWeights(),
            stgy.getMeanValue()
        );
        return new GDBModel(models, resAggregator, this::internalLabelToExternal);
    }

    /**
     * Only {@link GDBModel} instances can be updated by this trainer.
     *
     * @param mdl Model to check.
     * @return {@code true} iff {@code mdl} is a {@link GDBModel}.
     */
    @Override
    public boolean isUpdateable(ModelsComposition mdl) {
        return mdl instanceof GDBModel;
    }

    /** {@inheritDoc} */
    @Override
    public GDBTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (GDBTrainer)super.withEnvironmentBuilder(envBuilder);
    }

    /**
     * Prepares label handling for the given data. When this returns {@code false}, {@link #updateModel} stops
     * and delegates to {@code getLastTrainedModelOrThrowEmptyDatasetException}.
     *
     * @param builder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @return Whether training may proceed.
     */
    protected abstract <V, K> boolean learnLabels(DatasetBuilder<K, V> builder, Preprocessor<K, V> preprocessor);

    /**
     * Creates the trainer used for each base model of the composition.
     *
     * @return Base model trainer.
     */
    @NotNull
    protected abstract DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> buildBaseModelTrainer();

    /**
     * Maps a label from the caller's representation to the one used during boosting. Applied to labels when
     * computing the initial value and handed to the learning strategy.
     *
     * @param lbl External label.
     * @return Internal label.
     */
    protected abstract double externalLabelToInternal(double lbl);

    /**
     * Maps a raw composition prediction back to the caller's label representation; used by {@link GDBModel}.
     *
     * @param lbl Internal label.
     * @return External label.
     */
    protected abstract double internalLabelToExternal(double lbl);

    /**
     * Computes the mean of all (internally mapped) labels and the total number of labels across all partitions.
     * <p>
     * Note: type parameter {@code C} is unused and kept only for signature compatibility.
     *
     * @param envBuilder Learning environment builder.
     * @param builder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @return Tuple {@code (mean label, label count)}, or {@code null} if no labels were found.
     * @throws RuntimeException If building, computing on or closing the dataset fails; checked exceptions are
     *     wrapped.
     */
    protected <V, K, C extends Serializable> IgniteBiTuple<Double, Long> computeInitialValue(
        LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> builder, Preprocessor<K, V> preprocessor) {
        learningEnvironment().initDeployingContext(preprocessor);

        try (Dataset<EmptyContext, DecisionTreeData> dataset = builder.build(
            envBuilder,
            new EmptyContextBuilder<>(),
            new DecisionTreeDataBuilder<>(preprocessor, false),
            learningEnvironment()
        )) {
            IgniteBiTuple<Double, Long> sumAndCnt = dataset.compute(
                data -> {
                    double sum = Arrays.stream(data.getLabels()).map(this::externalLabelToInternal).sum();
                    // Box the count as Long (not Integer): the reducer and callers read get2() as Long.
                    return new IgniteBiTuple<>(sum, (long)data.getLabels().length);
                },
                (a, b) -> {
                    if (a == null) {
                        return b;
                    }
                    if (b == null) {
                        return a;
                    }
                    a.set1(a.get1() + b.get1());
                    a.set2(a.get2() + b.get2());
                    return a;
                }
            );

            // No partitions or no labels: report "no value" rather than a NaN mean (0 / 0) so the caller
            // takes its empty-dataset path.
            if (sumAndCnt == null || sumAndCnt.get2() == 0L) {
                return null;
            }

            sumAndCnt.set1(sumAndCnt.get1() / sumAndCnt.get2());
            return sumAndCnt;
        }
        catch (RuntimeException e) {
            // Do not double-wrap unchecked exceptions.
            throw e;
        }
        catch (Exception e) {
            // Dataset.close() declares a checked exception.
            throw new RuntimeException(e);
        }
    }

    /**
     * Sets the convergence checker factory handed to the learning strategy.
     *
     * @param factory Convergence checker factory.
     * @return This trainer.
     */
    public GDBTrainer withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
        this.checkConvergenceStgyFactory = factory;
        return this;
    }

    /**
     * Creates the learning strategy that runs the boosting loop. Subclasses may override to customize it.
     *
     * @return New learning strategy instance.
     */
    protected GDBLearningStrategy getLearningStrategy() {
        return new GDBLearningStrategy();
    }

    /**
     * Composition produced by {@link GDBTrainer}: aggregates base model predictions with a
     * {@link WeightedPredictionsAggregator} and maps the result back to the external label representation.
     */
    public static final class GDBModel extends ModelsComposition {
        /** Serial version uid. */
        private static final long serialVersionUID = 3476661240155508004L;

        /** Mapping from the internal (boosting) label representation to the external one. */
        private final IgniteFunction<Double, Double> internalToExternalLblMapping;

        /**
         * Creates a model.
         *
         * @param models Base models.
         * @param predictionsAggregator Aggregator of base model predictions.
         * @param internalToExternalLblMapping Mapping applied to the aggregated prediction.
         */
        public GDBModel(List<? extends IgniteModel<Vector, Double>> models,
            WeightedPredictionsAggregator predictionsAggregator,
            IgniteFunction<Double, Double> internalToExternalLblMapping) {
            super(models, predictionsAggregator);
            this.internalToExternalLblMapping = internalToExternalLblMapping;
        }

        /**
         * Predicts via the aggregated composition, then maps the value to the external label representation.
         *
         * @param features Feature vector.
         * @return Prediction in the external label representation.
         */
        @Override
        public Double predict(Vector features) {
            return internalToExternalLblMapping.apply(super.predict(features));
        }
    }
}

