/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.math.decompositions;

import java.io.Serializable;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.NoDataException;
import org.apache.ignite.ml.math.exceptions.NullArgumentException;
import org.apache.ignite.ml.math.functions.Functions;
import org.apache.ignite.ml.math.util.MatrixUtil;

/**
 * Solver operating on a pair of matrices {@code q} and {@code r}.
 *
 * <p>NOTE(review): the names suggest these are the factors of a QR decomposition (Q with
 * orthonormal columns, R upper triangular). This class never verifies that; every method simply
 * trusts the matrices handed to the constructor.
 */
public class QRDSolver implements Serializable {
    /** The {@code q} matrix supplied at construction time. */
    private final Matrix q;

    /** The {@code r} matrix supplied at construction time. */
    private final Matrix r;

    /**
     * Creates a solver around the given matrices. The references are stored as-is; no copy and
     * no validation is performed.
     *
     * @param q the {@code q} matrix.
     * @param r the {@code r} matrix.
     */
    public QRDSolver(Matrix q, Matrix r) {
        this.q = q;
        this.r = r;
    }

    /**
     * Computes {@code x} from {@code q}, {@code r} and the right-hand side {@code mtx} by forming
     * {@code y = q^T * mtx} and then running a back substitution over the rows of {@code r},
     * from the last usable row down to row 0.
     *
     * <p>No check is made for a zero diagonal entry in {@code r}; such an entry is divided by
     * as-is.
     *
     * @param mtx right-hand side; must have as many rows as {@code q}.
     * @return a matrix with {@code r.columnSize()} rows and {@code mtx.columnSize()} columns.
     * @throws NullArgumentException if {@code mtx} is {@code null}.
     * @throws IllegalArgumentException if {@code mtx.rowSize() != q.rowSize()}.
     */
    public Matrix solve(Matrix mtx) {
        // Same null contract as solve(Vector): fail with the library's exception, not a bare NPE.
        if (mtx == null) {
            throw new NullArgumentException();
        }
        if (mtx.rowSize() != this.q.rowSize()) {
            throw new IllegalArgumentException("Matrix row dimensions must agree.");
        }
        int cols = mtx.columnSize();
        // NOTE(review): assumes MatrixUtil.like returns a zero-filled matrix -- confirm.
        Matrix x = MatrixUtil.like(this.r, this.r.columnSize(), cols);
        Matrix qt = this.q.transpose();
        Matrix y = qt.times(mtx);
        for (int k = Math.min(this.r.columnSize(), this.q.rowSize()) - 1; k >= 0; --k) {
            // Presumably x[k,] += y[k,] / r[k,k]; with x starting at zero that acts as an
            // assignment. Depends on Functions.plusMult(c) being (a, b) -> a + c * b -- confirm.
            x.viewRow(k).map(y.viewRow(k), Functions.plusMult(1.0 / this.r.get(k, k)));
            if (k == 0) {
                // Nothing lies above row 0, so there is nothing left to eliminate.
                continue;
            }
            // Presumably the first k entries of column k of r (viewPart(offset, length) -- confirm).
            Vector rCol = this.r.viewColumn(k).viewPart(0, k);
            for (int c = 0; c < cols; ++c) {
                // Presumably y[0:k, c] -= rCol * x[k, c], removing row k's contribution from the
                // rows still to be solved.
                y.viewColumn(c).viewPart(0, k).map(rCol, Functions.plusMult(-x.get(k, c)));
            }
        }
        return x;
    }

    /**
     * Vector variant of {@link #solve(Matrix)}: copies {@code vec} into column 0 of a
     * {@code vec.size() x 1} matrix, solves, and copies column 0 of the result into a new vector
     * obtained from {@code vec.like(...)}.
     *
     * @param vec right-hand side vector.
     * @return a vector whose size equals the row count of the matrix solution.
     * @throws NullArgumentException if {@code vec} is {@code null}.
     * @throws NoDataException if {@code vec} has size 0.
     */
    public Vector solve(Vector vec) {
        if (vec == null) {
            throw new NullArgumentException();
        }
        if (vec.size() == 0) {
            throw new NoDataException();
        }
        Matrix res = this.solve(vec.likeMatrix(vec.size(), 1).assignColumn(0, vec));
        return vec.like(res.rowSize()).assign(res.viewColumn(0));
    }

    /**
     * Computes {@code q * augI * q^T}, where {@code augI} is a square matrix of order
     * {@code q.columnSize()} holding 1 on the diagonal positions whose index is below
     * {@code r.columnSize()} and 0 everywhere else.
     *
     * <p>NOTE(review): judging by the method name this is the regression "hat" matrix -- confirm
     * against callers.
     *
     * @return the product {@code q * augI * q^T}.
     */
    public Matrix calculateHat() {
        Matrix augI = MatrixUtil.like(this.q, this.q.columnSize(), this.q.columnSize());
        int n = augI.columnSize();
        int p = this.r.columnSize();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                // Every cell is written explicitly, so the result does not rely on the initial
                // contents of the matrix returned by MatrixUtil.like.
                augI.setX(i, j, (i == j && i < p) ? 1.0 : 0.0);
            }
        }
        return this.q.times(augI).times(this.q.transpose());
    }

    /**
     * Copies the part of {@code r} selected by {@code viewPart(0, p, 0, p)}, inverts the copy and
     * returns {@code rInv * rInv^T}.
     *
     * <p>NOTE(review): {@code viewPart(0, p, 0, p)} is presumably the leading {@code p x p}
     * sub-matrix -- confirm the argument order. {@code p} is not validated here.
     *
     * @param p size argument passed to {@code viewPart} for both dimensions.
     * @return the product of the inverted sub-matrix with its own transpose.
     */
    public Matrix calculateBetaVariance(int p) {
        Matrix rAug = MatrixUtil.copy(this.r.viewPart(0, p, 0, p));
        Matrix rInv = rAug.inverse();
        return rInv.times(rInv.transpose());
    }

    /**
     * Two solvers are equal when they are of exactly the same class and both their {@code q} and
     * {@code r} matrices are equal according to {@code Matrix.equals}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        QRDSolver solver = (QRDSolver)o;
        return this.q.equals(solver.q) && this.r.equals(solver.r);
    }

    /** Combines the hash codes of {@code q} and {@code r}, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        int res = this.q.hashCode();
        res = 31 * res + this.r.hashCode();
        return res;
    }

    /** Renders the solver as {@code "QRD Solver(<q rows> x <r columns>)"}. */
    @Override
    public String toString() {
        return String.format("QRD Solver(%d x %d)", this.q.rowSize(), this.r.columnSize());
    }
}

