/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.boosting;

import java.util.Arrays;
import java.util.List;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.DecisionTree;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;

/**
 * Gradient boosting learning strategy whose base models are decision trees.
 *
 * <p>The base model trainer supplied by the parent strategy must be a {@link DecisionTree}. A single
 * {@link DecisionTreeData} dataset is built per {@code update} call and reused for every boosting
 * iteration: before each tree is fitted, the labels stored in the dataset are overwritten in place
 * with the negated loss gradient of the current composition.
 */
public class GDBOnTreesLearningStrategy
extends GDBLearningStrategy {
    /**
     * Flag forwarded as the last argument to {@link DecisionTreeDataBuilder}.
     * NOTE(review): presumably enables index building inside the partition data - confirm against that class.
     */
    private final boolean useIdx;

    /**
     * Creates the strategy.
     *
     * @param useIdx Flag passed through to the {@link DecisionTreeDataBuilder} used in {@code update}.
     */
    public GDBOnTreesLearningStrategy(boolean useIdx) {
        this.useIdx = useIdx;
    }

    /**
     * Extends the given model with newly trained decision trees, one per boosting iteration.
     *
     * <p>Training stops after {@code cntOfIterations} iterations or as soon as the convergence checker
     * reports convergence for the current composition, whichever happens first. On return,
     * {@code compositionWeights} is resized to the number of models in the returned list.
     *
     * @param mdlToUpdate Model whose state seeds the list of models (via {@code initLearningState}).
     * @param datasetBuilder Builder of the dataset the trees are trained on.
     * @param featureExtractor Extracts a feature vector from an upstream key-value pair.
     * @param lbExtractor Extracts a label from an upstream key-value pair.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return List containing the initial models followed by the trees trained during this call.
     * @throws RuntimeException Wrapping any exception raised while building, using or closing the dataset.
     */
    @Override
    public <K, V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
        DatasetTrainer trainer = (DatasetTrainer)this.baseMdlTrainerBuilder.get();
        // This strategy calls DecisionTree.fit(Dataset) directly, so any other trainer type is a configuration error.
        assert (trainer instanceof DecisionTree);
        DecisionTree decisionTreeTrainer = (DecisionTree)trainer;
        List<Model<Vector, Double>> models = this.initLearningState(mdlToUpdate);
        ConvergenceChecker<K, V> convCheck = this.checkConvergenceStgyFactory.create(this.sampleSize, this.externalLbToInternalMapping, this.loss, datasetBuilder, featureExtractor, lbExtractor);
        // The dataset is built once and shared by all iterations; try-with-resources closes it on every exit path.
        try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(new EmptyContextBuilder(), new DecisionTreeDataBuilder(featureExtractor, lbExtractor, this.useIdx));){
            for (int i = 0; i < this.cntOfIterations; ++i) {
                // Only as many weights as there are models trained so far take part in the current composition.
                double[] weights = Arrays.copyOf(this.compositionWeights, models.size());
                WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, this.meanLbVal);
                ModelsComposition currComposition = new ModelsComposition(models, aggregator);
                if (convCheck.isConverged(dataset, currComposition)) {
                    break;
                }
                dataset.compute(part -> {
                    // Labels are overwritten below, so keep a copy of the original ones the first time a partition is seen.
                    if (part.getCopiedOriginalLabels() == null) {
                        part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
                    }
                    // Replace every label with the negated loss gradient; the next tree is fitted to these values.
                    for (int j = 0; j < part.getLabels().length; ++j) {
                        double mdlAnswer = currComposition.apply(VectorUtils.of(part.getFeatures()[j]));
                        double originalLbVal = (Double)this.externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
                        part.getLabels()[j] = -this.loss.gradient(this.sampleSize, originalLbVal, mdlAnswer);
                    }
                });
                // System.nanoTime() is monotonic, so the measured duration is not distorted by wall-clock adjustments.
                long startNanos = System.nanoTime();
                models.add(decisionTreeTrainer.fit(dataset));
                double learningTime = (double)(System.nanoTime() - startNanos) / 1.0E9;
                this.environment.logger(this.getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
            }
        }
        catch (Exception e) {
            // Also covers exceptions thrown while closing the dataset; the cause is preserved for callers.
            throw new RuntimeException(e);
        }
        // Drop the weights reserved for iterations that never ran (for example after early convergence).
        this.compositionWeights = Arrays.copyOf(this.compositionWeights, models.size());
        return models;
    }
}

