/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Iterative learning strategy that grows a list of base models: on every iteration a new base
 * model is fitted to the negated loss gradient computed against the current weighted composition
 * of the models trained so far.
 *
 * <p>The strategy is configured through the fluent {@code with...} methods, each of which stores
 * its argument and returns {@code this}.
 */
public class GDBLearningStrategy {
    /** Learning environment; used here to obtain the logger for training-time messages. */
    protected LearningEnvironment environment;

    /** Upper bound on the number of training iterations (and on the number of models added per call). */
    protected int cntOfIterations;

    /** Loss whose gradient defines the labels the next base model is fitted to. */
    protected Loss loss;

    /** Mapping applied to every label produced by the caller's label extractor before it is passed to the loss. */
    protected IgniteFunction<Double, Double> externalLbToInternalMapping;

    /** Supplier of the trainer used to fit base models. */
    protected IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> baseMdlTrainerBuilder;

    /**
     * Value passed to {@link WeightedPredictionsAggregator} together with the weights. When an
     * existing model is updated it is overwritten with that model's aggregator bias.
     */
    protected double meanLbVal;

    /** Sample size handed to the convergence checker factory and to {@link Loss#gradient}. */
    protected long sampleSize;

    /** Composition weights, one slot per model; (re)allocated by {@link #initLearningState}. */
    protected double[] compositionWeights;

    /** Factory of convergence checkers; defaults to {@code MeanAbsValueConvergenceCheckerFactory(0.001)}. */
    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001);

    /** Weight assigned to every slot reserved for a model that is yet to be trained. */
    private double defaultGradStepSize;

    /**
     * Trains a fresh list of models; equivalent to {@link #update} with no model to update.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return List of trained models.
     */
    public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        return this.update(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Trains up to {@code cntOfIterations} additional base models, starting from the models of
     * {@code mdlToUpdate} (or from an empty list when it is {@code null}). Training stops early as
     * soon as the convergence checker reports that the current composition has converged.
     *
     * @param mdlToUpdate Model to continue training from; may be {@code null}.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return List containing the pre-existing models followed by the newly trained ones.
     */
    public <K, V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        List<Model<Vector, Double>> models = this.initLearningState(mdlToUpdate);
        ConvergenceChecker<K, V> convCheck = this.checkConvergenceStgyFactory.create(this.sampleSize, this.externalLbToInternalMapping, this.loss, datasetBuilder, featureExtractor, lbExtractor);
        DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = this.baseMdlTrainerBuilder.get();

        for (int i = 0; i < this.cntOfIterations; i++) {
            // Only the weights of the models trained so far take part in the current composition.
            double[] weights = Arrays.copyOf(this.compositionWeights, models.size());
            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, this.meanLbVal);

            // Declared per iteration so that it is effectively final and may be captured by the lambda below.
            ModelsComposition currComposition = new ModelsComposition(models, aggregator);
            if (convCheck.isConverged(datasetBuilder, currComposition)) {
                break;
            }

            // The next base model is fitted to the negated gradient of the loss at the current prediction.
            IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> {
                Double realAnswer = this.externalLbToInternalMapping.apply(lbExtractor.apply(k, v));
                Double mdlAnswer = currComposition.apply(featureExtractor.apply(k, v));
                return -this.loss.gradient(this.sampleSize, realAnswer, mdlAnswer);
            };

            long startTs = System.currentTimeMillis();
            models.add(trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrap));
            double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
            this.environment.logger(this.getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }

        return models;
    }

    /**
     * Prepares the mutable state for a training run: collects the already trained models (if any),
     * allocates {@code compositionWeights} with room for {@code cntOfIterations} further models,
     * carries over the existing weights and bias, and fills the remaining slots with the default
     * gradient step size.
     *
     * <p>The predictions aggregator of {@code mdlToUpdate} is cast to
     * {@link WeightedPredictionsAggregator}.
     *
     * @param mdlToUpdate Model to continue training from; may be {@code null}.
     * @return Mutable list holding the models of {@code mdlToUpdate}, or an empty list.
     */
    @NotNull
    protected List<Model<Vector, Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate) {
        List<Model<Vector, Double>> models = new ArrayList<>();
        if (mdlToUpdate != null) {
            models.addAll(mdlToUpdate.getModels());
            WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator)mdlToUpdate.getPredictionsAggregator();
            this.meanLbVal = aggregator.getBias();
            this.compositionWeights = new double[models.size() + this.cntOfIterations];
            // Fetch the weights once and bulk-copy them instead of calling the getter per element.
            double[] existingWeights = aggregator.getWeights();
            System.arraycopy(existingWeights, 0, this.compositionWeights, 0, models.size());
        } else {
            this.compositionWeights = new double[this.cntOfIterations];
        }
        // Slots beyond the already trained models get the default step size.
        Arrays.fill(this.compositionWeights, models.size(), this.compositionWeights.length, this.defaultGradStepSize);
        return models;
    }

    /**
     * Sets the learning environment.
     *
     * @param environment Learning environment.
     * @return This instance.
     */
    public GDBLearningStrategy withEnvironment(LearningEnvironment environment) {
        this.environment = environment;
        return this;
    }

    /**
     * Sets the maximum number of training iterations.
     *
     * @param cntOfIterations Count of iterations.
     * @return This instance.
     */
    public GDBLearningStrategy withCntOfIterations(int cntOfIterations) {
        this.cntOfIterations = cntOfIterations;
        return this;
    }

    /**
     * Sets the loss whose gradient is used to build training labels.
     *
     * @param loss Loss.
     * @return This instance.
     */
    public GDBLearningStrategy withLossGradient(Loss loss) {
        this.loss = loss;
        return this;
    }

    /**
     * Sets the mapping applied to extracted labels before they are passed to the loss.
     *
     * @param externalLbToInternal External-to-internal label mapping.
     * @return This instance.
     */
    public GDBLearningStrategy withExternalLabelToInternal(IgniteFunction<Double, Double> externalLbToInternal) {
        this.externalLbToInternalMapping = externalLbToInternal;
        return this;
    }

    /**
     * Sets the supplier of the base model trainer.
     *
     * @param buildBaseMdlTrainer Supplier of the base model trainer.
     * @return This instance.
     */
    public GDBLearningStrategy withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> buildBaseMdlTrainer) {
        this.baseMdlTrainerBuilder = buildBaseMdlTrainer;
        return this;
    }

    /**
     * Sets the mean label value.
     *
     * @param meanLbVal Mean label value.
     * @return This instance.
     */
    public GDBLearningStrategy withMeanLabelValue(double meanLbVal) {
        this.meanLbVal = meanLbVal;
        return this;
    }

    /**
     * Sets the sample size.
     *
     * @param sampleSize Sample size.
     * @return This instance.
     */
    public GDBLearningStrategy withSampleSize(long sampleSize) {
        this.sampleSize = sampleSize;
        return this;
    }

    /**
     * Sets the composition weights. The array is stored by reference, not copied.
     *
     * @param compositionWeights Composition weights.
     * @return This instance.
     */
    public GDBLearningStrategy withCompositionWeights(double[] compositionWeights) {
        this.compositionWeights = compositionWeights;
        return this;
    }

    /**
     * Sets the factory used to create the convergence checker.
     *
     * @param factory Convergence checker factory.
     * @return This instance.
     */
    public GDBLearningStrategy withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
        this.checkConvergenceStgyFactory = factory;
        return this;
    }

    /**
     * Sets the weight assigned to each newly trained model.
     *
     * @param defaultGradStepSize Default gradient step size.
     * @return This instance.
     */
    public GDBLearningStrategy withDefaultGradStepSize(double defaultGradStepSize) {
        this.defaultGradStepSize = defaultGradStepSize;
        return this;
    }

    /**
     * Returns the composition weights array held by this strategy (the internal array itself, not a copy).
     *
     * @return Composition weights.
     */
    public double[] getCompositionWeights() {
        return this.compositionWeights;
    }

    /**
     * Returns the mean label value currently held by this strategy.
     *
     * @return Mean label value.
     */
    public double getMeanValue() {
        return this.meanLbVal;
    }
}

