/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.math.decompositions;

import org.apache.ignite.ml.math.Destroyable;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
import org.apache.ignite.ml.math.exceptions.NonPositiveDefiniteMatrixException;
import org.apache.ignite.ml.math.exceptions.NonSymmetricMatrixException;
import org.apache.ignite.ml.math.util.MatrixUtil;

/**
 * Cholesky decomposition {@code A = L * L^T} of a square, symmetric, positive-definite matrix {@code A}.
 * <p>
 * Only the upper-triangular factor {@code L^T} is kept internally, as a row-major {@code double[][]}.
 * {@code L} is obtained by transposing it. The factor matrices returned by {@link #getL()} and {@link #getLT()}
 * are built lazily, cached, and released by {@link #destroy()}.
 */
public class CholeskyDecomposition
implements Destroyable {
    /** Default relative tolerance used when checking that the input matrix is symmetric. */
    public static final double DFLT_REL_SYMMETRY_THRESHOLD = 1.0E-15;
    /** Default absolute threshold; a pivot less than or equal to it rejects the matrix as non positive-definite. */
    public static final double DFLT_ABS_POSITIVITY_THRESHOLD = 1.0E-10;

    /** Row-major upper-triangular factor {@code L^T}; entries below the diagonal are zero. */
    private final double[][] lTData;
    /** Lazily built lower-triangular factor {@code L}, or {@code null} if not built (or destroyed). */
    private Matrix cachedL;
    /** Lazily built upper-triangular factor {@code L^T}, or {@code null} if not built (or destroyed). */
    private Matrix cachedLT;
    /** The decomposed matrix; used as the prototype when creating result matrices and vectors. */
    private final Matrix origin;

    /**
     * Decomposes {@code mtx} using {@link #DFLT_REL_SYMMETRY_THRESHOLD} and {@link #DFLT_ABS_POSITIVITY_THRESHOLD}.
     *
     * @param mtx Square, symmetric, positive-definite matrix to decompose.
     * @throws CardinalityException If {@code mtx} is not square.
     * @throws NonSymmetricMatrixException If {@code mtx} is not symmetric within the default tolerance.
     * @throws NonPositiveDefiniteMatrixException If a pivot does not exceed the default positivity threshold.
     */
    public CholeskyDecomposition(Matrix mtx) {
        this(mtx, DFLT_REL_SYMMETRY_THRESHOLD, DFLT_ABS_POSITIVITY_THRESHOLD);
    }

    /**
     * Decomposes {@code mtx} with explicit tolerances.
     *
     * @param mtx Square, symmetric, positive-definite matrix to decompose.
     * @param relSymmetryThreshold Relative tolerance. Entries {@code a_ij} and {@code a_ji} are treated as equal when
     *      {@code |a_ij - a_ji| <= relSymmetryThreshold * max(|a_ij|, |a_ji|)}.
     * @param absPositivityThreshold Every pivot encountered during factorization must be strictly greater than
     *      this value.
     * @throws CardinalityException If {@code mtx} is not square.
     * @throws NonSymmetricMatrixException If a pair of mirrored entries violates the symmetry tolerance.
     * @throws NonPositiveDefiniteMatrixException If a pivot is less than or equal to {@code absPositivityThreshold}.
     */
    public CholeskyDecomposition(Matrix mtx, double relSymmetryThreshold, double absPositivityThreshold) {
        assert (mtx != null);
        if (mtx.columnSize() != mtx.rowSize()) {
            throw new CardinalityException(mtx.rowSize(), mtx.columnSize());
        }
        this.origin = mtx;
        int order = mtx.rowSize();
        this.lTData = this.toDoubleArr(mtx);
        checkSymmetryAndClearLower(this.lTData, order, relSymmetryThreshold);
        factorize(this.lTData, order, absPositivityThreshold);
    }

    /**
     * Verifies that {@code data} is symmetric within the given relative tolerance and zeroes its strictly lower
     * triangle, leaving only the upper triangle for the in-place factorization.
     */
    private static void checkSymmetryAndClearLower(double[][] data, int order, double relSymmetryThreshold) {
        for (int i = 0; i < order; ++i) {
            double[] rowI = data[i];
            for (int j = i + 1; j < order; ++j) {
                double[] rowJ = data[j];
                double aIJ = rowI[j];
                double aJI = rowJ[i];
                double maxDelta = relSymmetryThreshold * Math.max(Math.abs(aIJ), Math.abs(aJI));
                if (Math.abs(aIJ - aJI) > maxDelta) {
                    throw new NonSymmetricMatrixException(i, j, relSymmetryThreshold);
                }
                rowJ[i] = 0.0;
            }
        }
    }

    /**
     * Overwrites the upper triangle of {@code lT} with the Cholesky factor {@code L^T}. Each step applies a
     * rank-one update to the trailing submatrix.
     */
    private static void factorize(double[][] lT, int order, double absPositivityThreshold) {
        for (int i = 0; i < order; ++i) {
            double[] ltI = lT[i];
            if (ltI[i] <= absPositivityThreshold) {
                throw new NonPositiveDefiniteMatrixException(ltI[i], i, absPositivityThreshold);
            }
            ltI[i] = Math.sqrt(ltI[i]);
            double inverse = 1.0 / ltI[i];
            // q runs right-to-left. When row q is updated, every ltI[p] with p >= q has therefore already
            // been scaled by 1 / l_ii.
            for (int q = order - 1; q > i; --q) {
                ltI[q] *= inverse;
                double[] ltQ = lT[q];
                for (int p = q; p < order; ++p) {
                    ltQ[p] -= ltI[q] * ltI[p];
                }
            }
        }
    }

    /**
     * Destroys the cached factor matrices, if they were built, and forgets them. A later call to
     * {@link #getL()} or {@link #getLT()} rebuilds a fresh matrix instead of returning a destroyed one.
     */
    @Override
    public void destroy() {
        if (this.cachedL != null) {
            this.cachedL.destroy();
            this.cachedL = null;
        }
        if (this.cachedLT != null) {
            this.cachedLT.destroy();
            this.cachedLT = null;
        }
    }

    /**
     * Returns the lower-triangular factor {@code L}, computed as the transpose of {@link #getLT()}.
     * The matrix is cached: repeated calls return the same instance until {@link #destroy()} is called.
     *
     * @return Lower-triangular factor {@code L}.
     */
    public Matrix getL() {
        if (this.cachedL == null) {
            this.cachedL = this.getLT().transpose();
        }
        return this.cachedL;
    }

    /**
     * Returns the upper-triangular factor {@code L^T}. The matrix is created like the original matrix and
     * cached, so repeated calls return the same instance until {@link #destroy()} is called.
     *
     * @return Upper-triangular factor {@code L^T}.
     */
    public Matrix getLT() {
        if (this.cachedLT == null) {
            Matrix like = MatrixUtil.like(this.origin, this.origin.rowSize(), this.origin.columnSize());
            like.assign(this.lTData);
            this.cachedLT = like;
        }
        return this.cachedLT;
    }

    /**
     * Returns the determinant of the decomposed matrix. It equals the product of the squared diagonal
     * entries of {@code L^T}.
     *
     * @return Determinant of the original matrix.
     */
    public double getDeterminant() {
        double determinant = 1.0;
        for (int i = 0; i < this.lTData.length; ++i) {
            double lTii = this.lTData[i][i];
            determinant *= lTii * lTii;
        }
        return determinant;
    }

    /**
     * Solves {@code A * x = b}, where {@code A} is the decomposed matrix. The method runs forward substitution
     * with {@code L}, then back substitution with {@code L^T}. The vector {@code b} is read only and is not
     * modified.
     *
     * @param b Right-hand side; its size must equal the order of {@code A}.
     * @return New vector {@code x}, created like the original matrix.
     * @throws CardinalityException If {@code b.size()} differs from the order of {@code A}.
     */
    public Vector solve(Vector b) {
        int m = this.lTData.length;
        if (b.size() != m) {
            throw new CardinalityException(b.size(), m);
        }
        // Work on a private copy. Using b's storage array directly would overwrite the caller's vector
        // and would fail for vectors that have no backing array.
        double[] x = new double[m];
        for (int i = 0; i < m; ++i) {
            x[i] = b.get(i);
        }
        // Solve L * y = b. Column j of L is row j of L^T.
        for (int j = 0; j < m; ++j) {
            double[] lJ = this.lTData[j];
            x[j] /= lJ[j];
            double xJ = x[j];
            for (int i = j + 1; i < m; ++i) {
                x[i] -= xJ * lJ[i];
            }
        }
        // Solve L^T * x = y.
        for (int j = m - 1; j >= 0; --j) {
            x[j] /= this.lTData[j][j];
            double xJ = x[j];
            for (int i = 0; i < j; ++i) {
                x[i] -= xJ * this.lTData[i][j];
            }
        }
        return MatrixUtil.likeVector(this.origin, m).assign(x);
    }

    /**
     * Solves {@code A * X = B} for every column of {@code B}, where {@code A} is the decomposed matrix.
     * The method runs forward substitution with {@code L}, then back substitution with {@code L^T}.
     *
     * @param b Right-hand side matrix; its row count must equal the order of {@code A}.
     * @return New {@code m x b.columnSize()} matrix {@code X}, created like the original matrix.
     * @throws CardinalityException If {@code b.rowSize()} differs from the order of {@code A}.
     */
    public Matrix solve(Matrix b) {
        int m = this.lTData.length;
        if (b.rowSize() != m) {
            throw new CardinalityException(b.rowSize(), m);
        }
        int nColB = b.columnSize();
        // Same conversion as the constructor, so matrices that are not array-based are also supported.
        double[][] x = this.toDoubleArr(b);
        // Solve L * Y = B.
        for (int j = 0; j < m; ++j) {
            double[] lJ = this.lTData[j];
            double lJJ = lJ[j];
            double[] xJ = x[j];
            for (int k = 0; k < nColB; ++k) {
                xJ[k] /= lJJ;
            }
            for (int i = j + 1; i < m; ++i) {
                double[] xI = x[i];
                double lJI = lJ[i];
                for (int k = 0; k < nColB; ++k) {
                    xI[k] -= xJ[k] * lJI;
                }
            }
        }
        // Solve L^T * X = Y.
        for (int j = m - 1; j >= 0; --j) {
            double lJJ = this.lTData[j][j];
            double[] xJ = x[j];
            for (int k = 0; k < nColB; ++k) {
                xJ[k] /= lJJ;
            }
            for (int i = 0; i < j; ++i) {
                double[] xI = x[i];
                double lIJ = this.lTData[i][j];
                for (int k = 0; k < nColB; ++k) {
                    xI[k] -= xJ[k] * lIJ;
                }
            }
        }
        return MatrixUtil.like(this.origin, m, nColB).assign(x);
    }

    /**
     * Converts {@code mtx} into a row-major {@code double[][]}. Array-based matrices go through
     * {@link MatrixUtil#unflatten}; any other matrix is read element by element with {@code get(row, col)}.
     */
    private double[][] toDoubleArr(Matrix mtx) {
        if (mtx.isArrayBased()) {
            return MatrixUtil.unflatten(mtx.getStorage().data(), mtx.columnSize(), mtx.getStorage().storageMode());
        }
        double[][] res = new double[mtx.rowSize()][mtx.columnSize()];
        for (int row = 0; row < mtx.rowSize(); ++row) {
            for (int col = 0; col < mtx.columnSize(); ++col) {
                res[row][col] = mtx.get(row, col);
            }
        }
        return res;
    }
}

