/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.impurity.gini;

import java.util.Arrays;
import java.util.Map;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.util.StepFunction;

/**
 * Calculates Gini impurity step functions for every feature column of a decision tree data partition.
 * <p>
 * For each column the rows are ordered by that column's value. For every distinct threshold, the per-class
 * label counts of rows on the left ({@code value <= threshold}) and on the right are recorded. Each recorded
 * pair of counts is wrapped into a {@link GiniImpurityMeasure}.
 */
public class GiniImpurityMeasureCalculator implements ImpurityMeasureCalculator<GiniImpurityMeasure> {
    private static final long serialVersionUID = -522995134128519679L;

    /**
     * Maps every label value to a class code. Codes are used as indices into count arrays of length
     * {@code lbEncoder.size()}, so they are expected to lie in {@code [0, lbEncoder.size())}.
     */
    private final Map<Double, Integer> lbEncoder;

    /**
     * Constructs a new Gini impurity measure calculator.
     *
     * @param lbEncoder Label encoder mapping each label value to a class code in {@code [0, lbEncoder.size())}.
     */
    public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder) {
        this.lbEncoder = lbEncoder;
    }

    /**
     * Builds one Gini impurity step function per feature column.
     * <p>
     * Side effect: {@code data} is sorted by each column in turn via {@link DecisionTreeData#sort(int)}. The
     * feature and label arrays are obtained once, before sorting, and read afterwards. This relies on
     * {@code sort} reordering those arrays in place.
     *
     * @param data Decision tree data partition.
     * @return One step function per feature column, or {@code null} if the partition has no rows.
     * @throws IllegalArgumentException If a label has no code in the label encoder.
     */
    @SuppressWarnings("unchecked") // Generic array creation is not expressible without an unchecked conversion.
    @Override
    public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data) {
        double[][] features = data.getFeatures();
        double[] labels = data.getLabels();

        // An empty partition yields null rather than an empty array; callers may rely on this to skip it.
        if (features.length == 0) {
            return null;
        }

        // Sorting only permutes rows, so the per-class totals are the same for every column: compute them once.
        long[] totalCounts = countLabels(labels);

        StepFunction<GiniImpurityMeasure>[] res = new StepFunction[features[0].length];
        for (int col = 0; col < res.length; col++) {
            data.sort(col);
            res[col] = calculateForColumn(features, labels, col, totalCounts);
        }

        return res;
    }

    /**
     * Counts the rows belonging to each class code.
     *
     * @param labels Row labels.
     * @return Array of length {@code lbEncoder.size()} whose element {@code c} is the number of rows with code {@code c}.
     */
    private long[] countLabels(double[] labels) {
        long[] counts = new long[lbEncoder.size()];
        for (double lb : labels) {
            counts[getLabelCode(lb)]++;
        }
        return counts;
    }

    /**
     * Builds the step function for a single column. Rows must already be sorted by {@code col}.
     *
     * @param features Feature matrix (rows sorted by {@code col}).
     * @param labels Labels, in the same row order as {@code features}.
     * @param col Column index.
     * @param totalCounts Per-class row totals; not modified.
     * @return Step function whose points are {@code -inf} followed by each distinct value of the column.
     */
    private StepFunction<GiniImpurityMeasure> calculateForColumn(double[][] features, double[] labels, int col,
        long[] totalCounts) {
        // At most one point per row plus the leading -inf point.
        double[] x = new double[features.length + 1];
        GiniImpurityMeasure[] y = new GiniImpurityMeasure[features.length + 1];
        int ptr = 0;

        long[] left = new long[totalCounts.length];
        long[] right = Arrays.copyOf(totalCounts, totalCounts.length);

        // Threshold -inf: every row is on the right side.
        x[ptr] = Double.NEGATIVE_INFINITY;
        y[ptr++] = snapshot(left, right);

        for (int i = 0; i < features.length; i++) {
            // Move row i from the right side to the left side.
            int code = getLabelCode(labels[i]);
            left[code]++;
            right[code]--;

            // Emit a point only after the last row of a run of equal values, so each distinct threshold appears once.
            if (i < features.length - 1 && features[i + 1][col] == features[i][col]) {
                continue;
            }

            x[ptr] = features[i][col];
            y[ptr++] = snapshot(left, right);
        }

        return new StepFunction<>(Arrays.copyOf(x, ptr), Arrays.copyOf(y, ptr));
    }

    /**
     * Wraps copies of the current counts into a measure; copies are needed because the arrays keep being mutated.
     *
     * @param left Per-class counts on the left side.
     * @param right Per-class counts on the right side.
     * @return Gini impurity measure over independent copies of both arrays.
     */
    private static GiniImpurityMeasure snapshot(long[] left, long[] right) {
        return new GiniImpurityMeasure(Arrays.copyOf(left, left.length), Arrays.copyOf(right, right.length));
    }

    /**
     * Returns the class code assigned to a label by the label encoder.
     * <p>
     * The check is explicit rather than an {@code assert}, so an unknown label is reported with a clear message
     * even when assertions are disabled.
     *
     * @param lb Label value.
     * @return Class code.
     * @throws IllegalArgumentException If the encoder has no code for {@code lb}.
     */
    int getLabelCode(double lb) {
        Integer code = lbEncoder.get(lb);
        if (code == null) {
            throw new IllegalArgumentException("Can't find code for label " + lb);
        }
        return code;
    }
}

