/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.boosting.convergence;

import java.io.Serializable;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;

/**
 * Base class for convergence checks of a {@link ModelsComposition}.
 *
 * <p>A checker computes a mean error of the current composition over a dataset (the aggregation is
 * defined by subclasses in {@link #computeMeanErrorOnDataset}). It reports convergence once that
 * error is strictly below the configured precision, or when the error is NaN.
 *
 * @param <K> Type of a key in upstream data.
 * @param <V> Type of a value in upstream data.
 */
public abstract class ConvergenceChecker<K, V> implements Serializable {
    private static final long serialVersionUID = 710762134746674105L;

    /** Sample size forwarded to {@link Loss#gradient}. */
    private final long sampleSize;

    /** Maps a label produced by {@link #lbExtractor} into the label space used by {@link #loss}. */
    private final IgniteFunction<Double, Double> externalLbToInternalMapping;

    /** Loss function whose (negated) gradient is used as the per-sample error. */
    private final Loss loss;

    /** Extracts the feature vector from an upstream key/value pair. */
    private final IgniteBiFunction<K, V, Vector> featureExtractor;

    /** Extracts the label from an upstream key/value pair. */
    private final IgniteBiFunction<K, V, Double> lbExtractor;

    /** Convergence threshold; expected to satisfy {@code 0 <= precision < 1}. */
    private final double precision;

    /**
     * Creates a new convergence checker.
     *
     * @param sampleSize Sample size forwarded to {@link Loss#gradient}.
     * @param externalLbToInternalMapping Mapping from external labels to the loss's label space.
     * @param loss Loss function.
     * @param datasetBuilder Dataset builder. Not used by this constructor; kept for API compatibility.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param precision Convergence threshold; must satisfy {@code 0 <= precision < 1}
     *     (verified only when assertions are enabled).
     */
    public ConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor, double precision) {
        assert (precision < 1.0 && precision >= 0.0);
        this.sampleSize = sampleSize;
        this.externalLbToInternalMapping = externalLbToInternalMapping;
        this.loss = loss;
        this.featureExtractor = featureExtractor;
        this.lbExtractor = lbExtractor;
        this.precision = precision;
    }

    /**
     * Builds a temporary dataset from {@code datasetBuilder} and checks convergence on it.
     *
     * <p>The dataset is closed before this method returns.
     *
     * @param datasetBuilder Builder of the dataset to evaluate on.
     * @param currMdl Current composition.
     * @return {@code true} if the composition is considered converged.
     * @throws RuntimeException If evaluation fails. A checked failure (e.g. while closing the
     *     dataset) is wrapped; an unchecked one is propagated unchanged.
     */
    public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
        try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
            new EmptyContextBuilder<>(),
            new FeatureMatrixWithLabelsOnHeapDataBuilder<>(featureExtractor, lbExtractor))) {
            return isConverged(dataset, currMdl);
        }
        catch (RuntimeException e) {
            // Already unchecked: do not wrap it a second time, keep the original type and stack trace.
            throw e;
        }
        catch (Exception e) {
            // Closing an AutoCloseable resource may throw a checked exception.
            throw new RuntimeException(e);
        }
    }

    /**
     * Checks convergence of the composition on an already built dataset.
     *
     * @param dataset Dataset to evaluate on.
     * @param currMdl Current composition.
     * @return {@code true} if the mean error is strictly below the precision, or if it is NaN.
     */
    public boolean isConverged(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
        ModelsComposition currMdl) {
        double error = computeMeanErrorOnDataset(dataset, currMdl);
        // NaN is deliberately treated as "converged".
        // NOTE(review): presumably NaN arises when there is nothing to average — confirm in subclasses.
        return error < precision || Double.isNaN(error);
    }

    /**
     * Computes the mean error of the composition over the dataset.
     *
     * @param dataset Dataset to evaluate on.
     * @param currMdl Current composition.
     * @return Mean error; must not be {@code null}.
     */
    public abstract Double computeMeanErrorOnDataset(
        Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, ModelsComposition currMdl);

    /**
     * Computes the error of the composition on a single sample.
     *
     * <p>The error is the negated {@link Loss#gradient} evaluated at the mapped label and the
     * composition's prediction.
     *
     * @param features Feature vector of the sample.
     * @param answer External label of the sample (mapped through the label mapping before use).
     * @param currMdl Current composition.
     * @return Negated loss gradient for this sample.
     */
    public double computeError(Vector features, Double answer, ModelsComposition currMdl) {
        Double realAnswer = externalLbToInternalMapping.apply(answer);
        Double mdlAnswer = currMdl.apply(features);
        return -loss.gradient(sampleSize, realAnswer, mdlAnswer);
    }
}

