/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.boosting.convergence.median;

import java.util.Arrays;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;

/**
 * Convergence checker that estimates the model error as a "median of medians".
 *
 * <p>Each dataset partition computes the median of the absolute per-row errors reported by
 * {@code computeError}. The per-partition medians are then concatenated, and their median is the
 * final estimate. This approximates the median absolute error over the whole dataset without
 * gathering every row's error in one place. Partitions that contain no rows contribute no median.
 *
 * @param <K> Type of a key in upstream data.
 * @param <V> Type of a value in upstream data.
 */
public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K, V> {
    private static final long serialVersionUID = 4902502002933415287L;

    /**
     * Creates a new instance. All arguments are forwarded unchanged to the
     * {@link ConvergenceChecker} constructor.
     *
     * @param sampleSize Sample size (semantics defined by {@code ConvergenceChecker}).
     * @param lblMapping Label mapping function (semantics defined by {@code ConvergenceChecker}).
     * @param loss Loss function (semantics defined by {@code ConvergenceChecker}).
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @param precision Precision; presumably the convergence threshold, see {@code ConvergenceChecker}.
     */
    public MedianOfMedianConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> lblMapping, Loss loss,
        DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor, double precision) {
        super(sampleSize, lblMapping, loss, datasetBuilder, preprocessor, precision);
    }

    /**
     * Computes the median of per-partition medians of absolute errors.
     *
     * <p>Despite the inherited method name, the returned value is a median-based estimate, not a mean.
     *
     * @param dataset Dataset whose partitions hold features and labels.
     * @param mdl Model composition being evaluated.
     * @return Median of the per-partition median absolute errors, or {@link Double#POSITIVE_INFINITY}
     *     if no partition contains any rows.
     */
    @Override
    public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
        ModelsComposition mdl) {
        double[] medians = dataset.compute(data -> computeMedian(mdl, data), this::reduce);

        // null: no partition produced a result. An empty array (all partitions empty) is handled
        // by getMedian, which also yields POSITIVE_INFINITY.
        if (medians == null)
            return Double.POSITIVE_INFINITY;

        return getMedian(medians);
    }

    /**
     * Computes the median absolute error of a single partition.
     *
     * @param mdl Model composition being evaluated.
     * @param data Partition data.
     * @return A one-element array holding the partition's median absolute error, or an empty array if
     *     the partition has no rows.
     */
    private double[] computeMedian(ModelsComposition mdl, FeatureMatrixWithLabelsOnHeapData data) {
        // An empty partition must not contribute a POSITIVE_INFINITY "median". That value would drag
        // the median of medians upward and could keep convergence from ever being detected.
        if (data == null || data.getLabels() == null || data.getLabels().length == 0)
            return new double[0];

        int rows = data.getLabels().length;
        double[] errors = new double[rows];
        for (int i = 0; i < rows; i++)
            errors[i] = Math.abs(computeError(VectorUtils.of(data.getFeatures()[i]), data.getLabels()[i], mdl));

        return new double[] {getMedian(errors)};
    }

    /**
     * Returns the median of the given values. For an even count, this is the average of the two
     * middle values.
     *
     * <p>Note: the argument array is sorted in place.
     *
     * @param errors Values to take the median of; reordered by this call.
     * @return Median value, or {@link Double#POSITIVE_INFINITY} if {@code errors} is empty.
     */
    private double getMedian(double[] errors) {
        if (errors.length == 0)
            return Double.POSITIVE_INFINITY;

        Arrays.sort(errors);

        int middleIdx = (errors.length - 1) / 2;
        if (errors.length % 2 == 1)
            return errors[middleIdx];

        return (errors[middleIdx + 1] + errors[middleIdx]) / 2.0;
    }

    /**
     * Reducer that concatenates two partial arrays of per-partition medians.
     *
     * @param left Left partial result; may be {@code null}.
     * @param right Right partial result; may be {@code null}.
     * @return Concatenation of {@code left} and {@code right}, or whichever is non-null if the other is
     *     {@code null}.
     */
    private double[] reduce(double[] left, double[] right) {
        if (left == null)
            return right;
        if (right == null)
            return left;

        double[] res = Arrays.copyOf(left, left.length + right.length);
        System.arraycopy(right, 0, res, left.length, right.length);
        return res;
    }
}

