/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.dataset.primitive;

import com.github.fommil.netlib.BLAS;
import java.io.Serializable;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.DatasetWrapper;
import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData;

/**
 * Dataset wrapper that computes simple per-column statistics (mean, standard deviation, covariance and
 * correlation) over {@link SimpleDatasetData} partitions.
 *
 * <p>Each partition's feature array is read in column-major order: with {@code rows} rows, column {@code c}
 * occupies indices {@code [c * rows, (c + 1) * rows)}. The number of columns is derived as
 * {@code features.length / rows}.
 *
 * <p>All statistics return {@code null} when no partition contributed any data.
 *
 * @param <C> Type of the partition context.
 */
public class SimpleDataset<C extends Serializable>
extends DatasetWrapper<C, SimpleDatasetData> {
    /** BLAS instance used to scale the accumulated vectors. */
    private static final BLAS blas = BLAS.getInstance();

    /**
     * Creates a new dataset that delegates all computations to the given dataset.
     *
     * @param delegate Dataset to wrap.
     */
    public SimpleDataset(Dataset<C, SimpleDatasetData> delegate) {
        super(delegate);
    }

    /**
     * Calculates the mean value of every column.
     *
     * @return Array with one mean per column, or {@code null} if no partition contributed any rows.
     */
    public double[] mean() {
        ValueWithCount<double[]> res = this.delegate.compute((data, partIdx) -> {
            double[] features = data.getFeatures();
            int rows = data.getRows();
            // An empty partition contributes nothing; bailing out here also avoids dividing by zero below.
            if (rows == 0) {
                return null;
            }
            int cols = features.length / rows;
            double[] colSums = new double[cols];
            for (int col = 0; col < cols; ++col) {
                for (int j = col * rows; j < (col + 1) * rows; ++j) {
                    colSums[col] += features[j];
                }
            }
            return new ValueWithCount<double[]>(colSums, rows);
        }, (a, b) -> a == null ? b : (b == null ? a : new ValueWithCount<double[]>(SimpleDataset.sum(a.val, b.val), a.cnt + b.cnt)));
        if (res == null) {
            return null;
        }
        // Turn the accumulated column sums into means in place.
        blas.dscal(res.val.length, 1.0 / (double)res.cnt, res.val, 1);
        return res.val;
    }

    /**
     * Calculates the standard deviation of every column. The squared deviations are divided by the total
     * number of rows (not by {@code rows - 1}).
     *
     * @return Array with one standard deviation per column, or {@code null} if no partition contributed any rows.
     */
    public double[] std() {
        double[] mean = this.mean();
        // No mean means there was no data, so skip the second pass entirely.
        if (mean == null) {
            return null;
        }
        ValueWithCount<double[]> res = this.delegate.compute(data -> {
            double[] features = data.getFeatures();
            int rows = data.getRows();
            // An empty partition contributes nothing; bailing out here also avoids dividing by zero below.
            if (rows == 0) {
                return null;
            }
            int cols = features.length / rows;
            double[] sqDevSums = new double[cols];
            for (int col = 0; col < cols; ++col) {
                for (int j = col * rows; j < (col + 1) * rows; ++j) {
                    sqDevSums[col] += Math.pow(features[j] - mean[col], 2.0);
                }
            }
            return new ValueWithCount<double[]>(sqDevSums, rows);
        }, (a, b) -> a == null ? b : (b == null ? a : new ValueWithCount<double[]>(SimpleDataset.sum(a.val, b.val), a.cnt + b.cnt)));
        if (res == null) {
            return null;
        }
        // Sum of squared deviations -> variance -> standard deviation, all in place.
        blas.dscal(res.val.length, 1.0 / (double)res.cnt, res.val, 1);
        for (int i = 0; i < res.val.length; ++i) {
            res.val[i] = Math.sqrt(res.val[i]);
        }
        return res.val;
    }

    /**
     * Calculates the covariance matrix of the columns. The accumulated products are divided by the total
     * number of rows (not by {@code rows - 1}).
     *
     * @return Square matrix indexed by {@code [firstCol][secondCol]}, or {@code null} if no partition
     *      contributed any rows.
     */
    public double[][] cov() {
        double[] mean = this.mean();
        // No mean means there was no data, so skip the second pass entirely.
        if (mean == null) {
            return null;
        }
        ValueWithCount<double[][]> res = this.delegate.compute(data -> {
            double[] features = data.getFeatures();
            int rows = data.getRows();
            // An empty partition contributes nothing; bailing out here also avoids dividing by zero below.
            if (rows == 0) {
                return null;
            }
            int cols = features.length / rows;
            double[][] acc = new double[cols][cols];
            for (int firstCol = 0; firstCol < cols; ++firstCol) {
                for (int secondCol = 0; secondCol < cols; ++secondCol) {
                    for (int k = 0; k < rows; ++k) {
                        double firstVal = features[rows * firstCol + k];
                        double secondVal = features[rows * secondCol + k];
                        acc[firstCol][secondCol] += (firstVal - mean[firstCol]) * (secondVal - mean[secondCol]);
                    }
                }
            }
            return new ValueWithCount<double[][]>(acc, rows);
        }, (a, b) -> a == null ? b : (b == null ? a : new ValueWithCount<double[][]>(SimpleDataset.sum(a.val, b.val), a.cnt + b.cnt)));
        return res != null ? SimpleDataset.scale(res.val, 1.0 / (double)res.cnt) : null;
    }

    /**
     * Calculates the correlation matrix of the columns by dividing every covariance entry by the product
     * of the two corresponding standard deviations.
     *
     * <p>A column with zero standard deviation produces non-finite entries (NaN or infinity) because the
     * division is not guarded.
     *
     * @return Square matrix indexed by {@code [i][j]}, or {@code null} if no partition contributed any rows.
     */
    public double[][] corr() {
        double[][] cov = this.cov();
        if (cov == null) {
            return null;
        }
        double[] std = this.std();
        if (std == null) {
            return null;
        }
        // Normalize the covariance matrix in place.
        for (int i = 0; i < cov.length; ++i) {
            for (int j = 0; j < cov[i].length; ++j) {
                cov[i][j] /= std[i] * std[j];
            }
        }
        return cov;
    }

    /**
     * Adds {@code b} to {@code a} element-wise, modifying {@code a} in place.
     *
     * @param a Accumulator vector; overwritten with the sum.
     * @param b Vector to add; must have at least {@code a.length} elements.
     * @return The same array instance as {@code a}.
     */
    private static double[] sum(double[] a, double[] b) {
        for (int i = 0; i < a.length; ++i) {
            a[i] += b[i];
        }
        return a;
    }

    /**
     * Adds {@code b} to {@code a} element-wise, modifying {@code a} in place.
     *
     * @param a Accumulator matrix; overwritten with the sum.
     * @param b Matrix to add; must be at least as large as {@code a} in both dimensions.
     * @return The same array instance as {@code a}.
     */
    private static double[][] sum(double[][] a, double[][] b) {
        for (int i = 0; i < a.length; ++i) {
            for (int j = 0; j < a[i].length; ++j) {
                a[i][j] += b[i][j];
            }
        }
        return a;
    }

    /**
     * Multiplies every element of {@code a} by {@code alpha}, modifying {@code a} in place.
     *
     * @param a Matrix to scale.
     * @param alpha Scaling factor.
     * @return The same array instance as {@code a}.
     */
    private static double[][] scale(double[][] a, double alpha) {
        for (int i = 0; i < a.length; ++i) {
            for (int j = 0; j < a[i].length; ++j) {
                a[i][j] *= alpha;
            }
        }
        return a;
    }

    /**
     * Pairs an accumulated value with the number of rows it was accumulated over.
     *
     * @param <V> Type of the accumulated value.
     */
    private static class ValueWithCount<V> {
        /** Accumulated value. */
        private final V val;

        /** Number of rows the value was accumulated over. */
        private final int cnt;

        /**
         * Creates a new value/count pair.
         *
         * @param val Accumulated value.
         * @param cnt Number of rows the value was accumulated over.
         */
        ValueWithCount(V val, int cnt) {
            this.val = val;
            this.cnt = cnt;
        }
    }
}

