/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.boosting;

import java.util.Arrays;
import java.util.List;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.jetbrains.annotations.NotNull;

/**
 * Base trainer for gradient-boosted model compositions.
 * <p>
 * Subclasses supply the trainer used for every base model ({@link #buildBaseModelTrainer()}) and the mapping
 * between the caller-visible label space and the internal label space used while boosting
 * ({@link #externalLabelToInternal(double)} / {@link #internalLabelToExternal(double)}). The boosting
 * iterations themselves are delegated to the {@link GDBLearningStrategy} returned by
 * {@link #getLearningStrategy()}.
 */
public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Double> {
    /** Default gradient step size handed to the learning strategy. */
    private final double gradientStep;

    /** Number of boosting iterations handed to the learning strategy. */
    private final int cntOfIterations;

    /** Loss whose gradient drives the boosting iterations. */
    protected final Loss loss;

    /** Factory for the convergence checker; defaults to a mean-absolute-value checker with precision 0.001. */
    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001);

    /**
     * Creates a trainer.
     *
     * @param gradStepSize Default gradient step size.
     * @param cntOfIterations Number of boosting iterations; must not be {@code null} (it is unboxed here).
     * @param loss Loss function used for gradient computation.
     */
    public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
        this.gradientStep = gradStepSize;
        this.cntOfIterations = cntOfIterations;
        this.loss = loss;
    }

    /**
     * Trains a new composition from scratch; equivalent to {@code updateModel(null, ...)}.
     */
    @Override
    public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Trains a new composition, or continues boosting an existing one.
     *
     * @param mdl Previously trained model to continue from, or {@code null} to train from scratch. When
     *     non-null it is cast to {@link GDBModel} (see {@link #checkState(ModelsComposition)}).
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Trained {@link GDBModel}, or the result of {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)}
     *     when there is nothing to learn from.
     */
    @Override
    protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        if (!learnLabels(datasetBuilder, featureExtractor, lbExtractor)) {
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
        }

        IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(datasetBuilder, featureExtractor, lbExtractor);
        if (initAndSampleSize == null) {
            // No labels to compute an initial value from: handled like the learnLabels() check above.
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
        }

        double mean = initAndSampleSize.get1();
        long sampleSize = initAndSampleSize.get2();

        long learningStartTs = System.currentTimeMillis();

        GDBLearningStrategy stgy = getLearningStrategy()
            .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
            .withExternalLabelToInternal(this::externalLabelToInternal)
            .withCntOfIterations(cntOfIterations)
            .withEnvironment(environment)
            .withLossGradient(loss)
            .withSampleSize(sampleSize)
            .withMeanLabelValue(mean)
            .withDefaultGradStepSize(gradientStep)
            .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory);

        List<Model<Vector, Double>> models = mdl != null
            ? stgy.update((GDBModel)mdl, datasetBuilder, featureExtractor, lbExtractor)
            : stgy.learnModels(datasetBuilder, featureExtractor, lbExtractor);

        double learningTime = (System.currentTimeMillis() - learningStartTs) / 1000.0;
        environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);

        WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(
            stgy.getCompositionWeights(),
            stgy.getMeanValue()
        );
        return new GDBModel(models, resAggregator, this::internalLabelToExternal);
    }

    /**
     * Only {@link GDBModel} instances can be updated, because {@link #updateModel} casts to that type.
     */
    @Override
    protected boolean checkState(ModelsComposition mdl) {
        return mdl instanceof GDBModel;
    }

    /**
     * Hook invoked before training. If it returns {@code false}, {@link #updateModel} skips training and returns
     * {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)} instead.
     * NOTE(review): presumably subclasses collect label metadata here — confirm in implementations.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return {@code true} if training should proceed.
     */
    protected abstract <V, K> boolean learnLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor);

    /**
     * @return Trainer used for every base model of the composition.
     */
    @NotNull
    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer();

    /**
     * Maps a caller-visible label to the internal label space. Applied to every label when computing the
     * initial value and handed to the learning strategy.
     *
     * @param lbl External label.
     * @return Internal label.
     */
    protected abstract double externalLabelToInternal(double lbl);

    /**
     * Maps an internal value back to the caller-visible label space; applied to the aggregated prediction
     * in {@link GDBModel#apply(Vector)}.
     *
     * @param lbl Internal value.
     * @return External label.
     */
    protected abstract double internalLabelToExternal(double lbl);

    /**
     * Computes the mean of all labels (after {@link #externalLabelToInternal(double)}) and the total label count.
     *
     * @param builder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Tuple of (mean internal label, sample size), or {@code null} if the dataset yields no labels.
     * @throws RuntimeException Wrapping any exception raised while building, computing over or closing the dataset.
     */
    protected <V, K> IgniteBiTuple<Double, Long> computeInitialValue(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        // The partition context is not used here, hence the wildcard.
        try (Dataset<?, DecisionTreeData> dataset = builder.build(
            new EmptyContextBuilder<>(),
            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, false))) {

            IgniteBiTuple<Double, Long> sumAndCnt = dataset.compute(
                data -> {
                    double[] labels = data.getLabels();
                    double sum = Arrays.stream(labels).map(this::externalLabelToInternal).sum();
                    // Box the count as Long (not Integer): the reducer and the division below read it as Long.
                    return new IgniteBiTuple<>(sum, (long)labels.length);
                },
                (a, b) -> {
                    if (a == null) {
                        return b;
                    }
                    if (b == null) {
                        return a;
                    }
                    a.set1(a.get1() + b.get1());
                    a.set2(a.get2() + b.get2());
                    return a;
                }
            );

            // No partitions or no labels: a mean would be 0 / 0 = NaN, so report "nothing to learn from".
            if (sumAndCnt == null || sumAndCnt.get2() == 0L) {
                return null;
            }

            sumAndCnt.set1(sumAndCnt.get1() / sumAndCnt.get2());
            return sumAndCnt;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sets the factory used to create the convergence checker.
     *
     * @param factory Convergence checker factory.
     * @return This trainer.
     */
    public GDBTrainer withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
        this.checkConvergenceStgyFactory = factory;
        return this;
    }

    /**
     * Creates the learning strategy that performs the boosting iterations. Subclasses may override this to
     * supply a specialized strategy.
     *
     * @return New learning strategy.
     */
    protected GDBLearningStrategy getLearningStrategy() {
        return new GDBLearningStrategy();
    }

    /**
     * Composition produced by {@link GDBTrainer}: aggregates base-model predictions with a
     * {@link WeightedPredictionsAggregator}, then maps the result back to the external label space.
     */
    public static class GDBModel extends ModelsComposition {
        private static final long serialVersionUID = 3476661240155508004L;

        /** Maps the aggregated internal prediction to the external label space. */
        private final IgniteFunction<Double, Double> internalToExternalLblMapping;

        /**
         * @param models Base models.
         * @param predictionsAggregator Aggregator combining base-model predictions.
         * @param internalToExternalLblMapping Mapping applied to the aggregated prediction.
         */
        public GDBModel(List<? extends Model<Vector, Double>> models, WeightedPredictionsAggregator predictionsAggregator, IgniteFunction<Double, Double> internalToExternalLblMapping) {
            super(models, predictionsAggregator);
            this.internalToExternalLblMapping = internalToExternalLblMapping;
        }

        /** {@inheritDoc} */
        @Override
        public Double apply(Vector features) {
            return internalToExternalLblMapping.apply(super.apply(features));
        }
    }
}

