/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.boosting;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Learning strategy used by the gradient boosting trainer to build a list of base models.
 *
 * <p>On every iteration the strategy wraps the models trained so far into a
 * {@link ModelsComposition}, asks the configured {@link ConvergenceChecker} whether training
 * may stop, and otherwise fits one more base model whose labels are the negated loss gradient
 * of the current composition's prediction.
 *
 * <p>The instance is configured through the fluent {@code with*} methods. It is mutable:
 * {@link #update} rewrites the composition weights (and, when an existing model is updated,
 * the mean label value) kept in this object.
 */
public class GDBLearningStrategy {
    /** Learning environment builder; passed to the convergence checker. */
    protected LearningEnvironmentBuilder envBuilder;

    /** Learning environment built from {@link #envBuilder}; used for logging. */
    protected LearningEnvironment trainerEnvironment;

    /** Maximum number of models trained by a single {@link #update} call. */
    protected int cntOfIterations;

    /** Loss function whose gradient defines the labels for each new base model. */
    protected Loss loss;

    /** Mapping applied to a dataset label before it is passed to the loss gradient. */
    protected IgniteFunction<Double, Double> externalLbToInternalMapping;

    /** Supplier of the trainer used to fit base models. */
    protected IgniteSupplier<DatasetTrainer<? extends IgniteModel<Vector, Double>, Double>> baseMdlTrainerBuilder;

    /** Bias passed to the {@link WeightedPredictionsAggregator} of the composition. */
    protected double meanLbVal;

    /** Sample size passed to the loss gradient and to the convergence checker factory. */
    protected long sampleSize;

    /** Weights of the models in the composition; slots for new models hold the default step size. */
    protected double[] compositionWeights;

    /** Factory of the convergence checker consulted before each training iteration. */
    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001);

    /** Weight assigned to every model that is going to be trained by this strategy. */
    private double defaultGradStepSize;

    /**
     * Trains a new list of models from scratch. Equivalent to calling {@link #update} with a
     * {@code null} model.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor producing labeled vectors.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return List of trained models.
     * @throws IllegalStateException If the learning environment was not configured.
     */
    public <K, V> List<IgniteModel<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        return this.update(null, datasetBuilder, preprocessor);
    }

    /**
     * Trains additional models, optionally starting from the models of an existing composition.
     *
     * <p>Training stops after {@code cntOfIterations} new models or as soon as the convergence
     * checker reports convergence for the current composition, whichever happens first.
     *
     * @param mdlToUpdate Model to continue training from, or {@code null} to start from scratch.
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor producing labeled vectors.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return List containing the models of {@code mdlToUpdate} (if any) followed by the newly
     *     trained models.
     * @throws IllegalStateException If the learning environment was not configured via
     *     {@link #withEnvironmentBuilder(LearningEnvironmentBuilder)}.
     */
    public <K, V> List<IgniteModel<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, final Preprocessor<K, V> preprocessor) {
        if (this.trainerEnvironment == null) {
            throw new IllegalStateException("Learning environment builder is not set.");
        }

        List<IgniteModel<Vector, Double>> models = this.initLearningState(mdlToUpdate);
        ConvergenceChecker<K, V> convCheck = this.checkConvergenceStgyFactory.create(this.sampleSize, this.externalLbToInternalMapping, this.loss, datasetBuilder, preprocessor);
        DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = this.baseMdlTrainerBuilder.get();

        for (int i = 0; i < this.cntOfIterations; ++i) {
            // Only the weights of the models trained so far take part in the current composition.
            double[] weights = Arrays.copyOf(this.compositionWeights, models.size());
            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, this.meanLbVal);

            // Declared per iteration and final so that the anonymous extractor below can capture it.
            final ModelsComposition currComposition = new ModelsComposition(models, aggregator);

            if (convCheck.isConverged(this.envBuilder, datasetBuilder, currComposition)) {
                break;
            }

            Vectorizer.VectorizerAdapter<K, V, Serializable, Double> extractor = new Vectorizer.VectorizerAdapter<K, V, Serializable, Double>() {
                /**
                 * Replaces the label of an upstream row with the negated loss gradient computed
                 * from the mapped label and the current composition's prediction.
                 */
                @Override
                public LabeledVector<Double> extract(K k, V v) {
                    LabeledVector<?> labeledVector = preprocessor.apply(k, v);
                    // Casts are kept because the preprocessor does not expose the label type.
                    Vector features = (Vector) labeledVector.features();
                    Double realAnswer = GDBLearningStrategy.this.externalLbToInternalMapping.apply((Double) labeledVector.label());
                    Double mdlAnswer = currComposition.predict(features);
                    return new LabeledVector<Double>(features, -GDBLearningStrategy.this.loss.gradient(GDBLearningStrategy.this.sampleSize, realAnswer, mdlAnswer));
                }
            };

            long startTs = System.currentTimeMillis();
            models.add(trainer.fit(datasetBuilder, extractor));
            double learningTime = (double) (System.currentTimeMillis() - startTs) / 1000.0;
            this.trainerEnvironment.logger(this.getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }

        return models;
    }

    /**
     * Prepares the model list and the composition weights for a training run.
     *
     * <p>When {@code mdlToUpdate} is given, its models, aggregator bias and aggregator weights
     * are taken over and the weights array is extended by {@code cntOfIterations} slots;
     * otherwise a weights array of {@code cntOfIterations} slots is allocated. All slots that
     * belong to models not trained yet are filled with the default gradient step size.
     *
     * @param mdlToUpdate Model to continue training from, or {@code null}.
     * @return Mutable list with the models of {@code mdlToUpdate}, or an empty list.
     */
    @NotNull
    protected List<IgniteModel<Vector, Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate) {
        List<IgniteModel<Vector, Double>> models = new ArrayList<IgniteModel<Vector, Double>>();
        if (mdlToUpdate != null) {
            models.addAll(mdlToUpdate.getModels());
            WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator();
            this.meanLbVal = aggregator.getBias();
            this.compositionWeights = new double[models.size() + this.cntOfIterations];
            System.arraycopy(aggregator.getWeights(), 0, this.compositionWeights, 0, models.size());
        } else {
            this.compositionWeights = new double[this.cntOfIterations];
        }
        // Slots past the already trained models receive the default step size.
        Arrays.fill(this.compositionWeights, models.size(), this.compositionWeights.length, this.defaultGradStepSize);
        return models;
    }

    /**
     * Sets the learning environment builder and builds the trainer environment from it.
     *
     * @param envBuilder Learning environment builder.
     * @return This instance.
     */
    public GDBLearningStrategy withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        this.envBuilder = envBuilder;
        this.trainerEnvironment = envBuilder.buildForTrainer();
        return this;
    }

    /**
     * Sets the maximum number of models trained per {@link #update} call.
     *
     * @param cntOfIterations Count of iterations.
     * @return This instance.
     */
    public GDBLearningStrategy withCntOfIterations(int cntOfIterations) {
        this.cntOfIterations = cntOfIterations;
        return this;
    }

    /**
     * Sets the loss function whose gradient is used to compute training labels.
     *
     * @param loss Loss function.
     * @return This instance.
     */
    public GDBLearningStrategy withLossGradient(Loss loss) {
        this.loss = loss;
        return this;
    }

    /**
     * Sets the mapping applied to dataset labels before they are passed to the loss gradient.
     *
     * @param externalLbToInternal Label mapping.
     * @return This instance.
     */
    public GDBLearningStrategy withExternalLabelToInternal(IgniteFunction<Double, Double> externalLbToInternal) {
        this.externalLbToInternalMapping = externalLbToInternal;
        return this;
    }

    /**
     * Sets the supplier of the trainer used to fit base models.
     *
     * @param buildBaseMdlTrainer Base model trainer supplier.
     * @return This instance.
     */
    public GDBLearningStrategy withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends IgniteModel<Vector, Double>, Double>> buildBaseMdlTrainer) {
        this.baseMdlTrainerBuilder = buildBaseMdlTrainer;
        return this;
    }

    /**
     * Sets the value used as the bias of the predictions aggregator.
     *
     * @param meanLbVal Mean label value.
     * @return This instance.
     */
    public GDBLearningStrategy withMeanLabelValue(double meanLbVal) {
        this.meanLbVal = meanLbVal;
        return this;
    }

    /**
     * Sets the sample size passed to the loss gradient and the convergence checker factory.
     *
     * @param sampleSize Sample size.
     * @return This instance.
     */
    public GDBLearningStrategy withSampleSize(long sampleSize) {
        this.sampleSize = sampleSize;
        return this;
    }

    /**
     * Sets the composition weights. The array is stored as is (no copy is made) and is
     * replaced by a new array on the next {@link #update} call.
     *
     * @param compositionWeights Composition weights.
     * @return This instance.
     */
    public GDBLearningStrategy withCompositionWeights(double[] compositionWeights) {
        this.compositionWeights = compositionWeights;
        return this;
    }

    /**
     * Sets the factory of the convergence checker consulted before each iteration.
     *
     * @param factory Convergence checker factory.
     * @return This instance.
     */
    public GDBLearningStrategy withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
        this.checkConvergenceStgyFactory = factory;
        return this;
    }

    /**
     * Sets the weight assigned to every model that is yet to be trained.
     *
     * @param defaultGradStepSize Default gradient step size.
     * @return This instance.
     */
    public GDBLearningStrategy withDefaultGradStepSize(double defaultGradStepSize) {
        this.defaultGradStepSize = defaultGradStepSize;
        return this;
    }

    /**
     * Returns the internal composition weights array (not a copy).
     *
     * @return Composition weights.
     */
    public double[] getCompositionWeights() {
        return this.compositionWeights;
    }

    /**
     * Returns the value used as the bias of the predictions aggregator.
     *
     * @return Mean label value.
     */
    public double getMeanValue() {
        return this.meanLbVal;
    }
}

