/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.selection.scoring.evaluator.aggregator;

import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator;
import org.apache.ignite.ml.selection.scoring.evaluator.context.EmptyContext;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Accumulates the statistics needed to compute regression quality metrics (MAE, MSE, RSS and the variance of
 * the true labels) over a test set.
 *
 * <p>Every accumulator uses {@link Double#NaN} to mean "nothing accumulated yet"; {@link #sum(double, double)}
 * treats a NaN operand as absent. Consequently an aggregator created with the no-arg constructor that has seen
 * no examples reports {@code NaN} from every metric getter. Partial aggregators can be combined with
 * {@link #mergeWith(RegressionMetricStatsAggregator)}.
 */
public class RegressionMetricStatsAggregator
implements MetricStatsAggregator<Double, EmptyContext<Double>, RegressionMetricStatsAggregator> {
    private static final long serialVersionUID = -2459352313996869235L;

    /** Number of aggregated examples. */
    private long n;

    /** Sum of |truth - prediction|; NaN means nothing accumulated yet. */
    private double absoluteError = Double.NaN;

    /** Residual sum of squares, sum of (truth - prediction)^2; NaN means nothing accumulated yet. */
    private double rss = Double.NaN;

    /** Sum of true labels; NaN means nothing accumulated yet. */
    private double sumOfYs = Double.NaN;

    /** Sum of squared true labels; NaN means nothing accumulated yet. */
    private double sumOfSquaredYs = Double.NaN;

    /** Creates an empty aggregator. */
    public RegressionMetricStatsAggregator() {
    }

    /**
     * Creates an aggregator from already accumulated statistics.
     *
     * @param n Number of aggregated examples.
     * @param absoluteError Sum of absolute errors (NaN if none).
     * @param rss Residual sum of squares (NaN if none).
     * @param sumOfYs Sum of true labels (NaN if none).
     * @param sumOfSquaredYs Sum of squared true labels (NaN if none).
     */
    public RegressionMetricStatsAggregator(long n, double absoluteError, double rss, double sumOfYs, double sumOfSquaredYs) {
        this.n = n;
        this.absoluteError = absoluteError;
        this.rss = rss;
        this.sumOfYs = sumOfYs;
        this.sumOfSquaredYs = sumOfSquaredYs;
    }

    /**
     * Adds one labeled example to the statistics.
     *
     * <p>The label and the model's prediction are validated with {@code A.notNull} before any state is
     * modified, so a rejected example leaves this aggregator unchanged.
     *
     * @param model Model whose prediction is compared against the label.
     * @param vector Labeled example; neither its label nor the model's prediction for it may be {@code null}.
     */
    @Override
    public void aggregate(IgniteModel<Vector, Double> model, LabeledVector<Double> vector) {
        Double prediction = model.predict(vector.features());
        Double truth = vector.label();

        // Check the values themselves: both are unboxed below, and a null would otherwise surface as an
        // anonymous unboxing NPE after the example count had already been bumped.
        A.notNull(truth, "Test set mustn't contain null labels");
        A.notNull(prediction, "Model mustn't return null answers");

        n++;

        double error = truth - prediction;
        absoluteError = sum(Math.abs(error), absoluteError);
        rss = sum(error * error, rss);
        sumOfYs = sum(truth, sumOfYs);
        sumOfSquaredYs = sum(truth * truth, sumOfSquaredYs);
    }

    /**
     * Combines the statistics of this aggregator with another one.
     *
     * @param other Aggregator to merge with; not modified.
     * @return New aggregator holding the combined statistics; {@code this} is not modified.
     */
    @Override
    public RegressionMetricStatsAggregator mergeWith(RegressionMetricStatsAggregator other) {
        return new RegressionMetricStatsAggregator(
            n + other.n,
            sum(absoluteError, other.absoluteError),
            sum(rss, other.rss),
            sum(sumOfYs, other.sumOfYs),
            sum(sumOfSquaredYs, other.sumOfSquaredYs)
        );
    }

    /** @return A fresh empty context; this aggregator needs no context data. */
    @Override
    public EmptyContext<Double> createInitializedContext() {
        return new EmptyContext<>();
    }

    /**
     * No-op: this aggregator needs no context data.
     *
     * @param context Ignored.
     */
    @Override
    public void initByContext(EmptyContext<Double> context) {
    }

    /** @return Mean absolute error, or NaN if nothing has been accumulated. */
    public double getMAE() {
        if (Double.isNaN(absoluteError)) {
            return Double.NaN;
        }
        return perExample(absoluteError);
    }

    /** @return Mean squared error, or NaN if nothing has been accumulated. */
    public double getMSE() {
        return perExample(rss);
    }

    /** @return Sum of squared deviations of the labels from their mean, computed as variance times count. */
    public double ysRss() {
        return ysVariance() * Math.max(n, 1L);
    }

    /**
     * Computes the population variance of the true labels as E[y^2] - E[y]^2.
     *
     * @return Label variance (never negative), or NaN if nothing has been accumulated.
     */
    public double ysVariance() {
        if (Double.isNaN(sumOfSquaredYs)) {
            return Double.NaN;
        }
        double meanOfSquares = perExample(sumOfSquaredYs);
        double mean = perExample(sumOfYs);
        // E[y^2] - E[y]^2 can come out slightly negative from floating-point cancellation when the labels are
        // (nearly) constant; a variance cannot be negative, so clamp. Math.max still propagates NaN.
        return Math.max(0.0, meanOfSquares - mean * mean);
    }

    /** @return Residual sum of squares, or NaN if nothing has been accumulated. */
    public double getRss() {
        return rss;
    }

    /** Divides an accumulated total by the example count, using 1 when no examples were counted. */
    private double perExample(double total) {
        return total / Math.max(n, 1L);
    }

    /**
     * Adds two accumulator values, treating NaN as "absent".
     *
     * @return The non-NaN operand if exactly one is NaN, NaN if both are, otherwise {@code v1 + v2}.
     */
    private double sum(double v1, double v2) {
        if (Double.isNaN(v1)) {
            return v2;
        }
        if (Double.isNaN(v2)) {
            return v1;
        }
        return v1 + v2;
    }
}

