/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu.context;

import jcuda.Pointer;
import jcuda.jcublas.cublasHandle;
import jcuda.jcusparse.JCusparse;
import jcuda.jcusparse.cusparseHandle;
import jcuda.jcusparse.cusparseMatDescr;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.utils.GPUStatistics;

/**
 * Handle to a sparse matrix stored on the GPU in compressed sparse row (CSR) form.
 *
 * <p>An instance holds three device pointers plus a cuSPARSE matrix descriptor:
 * <ul>
 *   <li>{@link #val} - {@code nnz} doubles (8 bytes each),</li>
 *   <li>{@link #rowPtr} - {@code rows + 1} ints (4 bytes each),</li>
 *   <li>{@link #colInd} - {@code nnz} ints (4 bytes each).</li>
 * </ul>
 * The number of rows is not stored here; callers pass it to the methods that need it.
 * Instances are created through the static {@code allocate*} factories or {@link #clone(int)}.
 */
public class CSRPointer {
    private static final Log LOG = LogFactory.getLog(CSRPointer.class.getName());

    /** Sparsity (nnz / (rows * cols)) below which a matrix is considered ultra sparse. */
    private static final double ULTRA_SPARSITY_TURN_POINT = 4.0E-5;

    /** Size in bytes of one element of the values array. */
    private static final long SIZE_OF_DOUBLE = 8L;
    /** Size in bytes of one element of the row-pointer / column-index arrays. */
    private static final long SIZE_OF_INT = 4L;

    // Numeric values of the CUDA runtime's cudaMemcpyKind enum.
    private static final int MEMCPY_HOST_TO_DEVICE = 1;
    private static final int MEMCPY_DEVICE_TO_HOST = 2;
    private static final int MEMCPY_DEVICE_TO_DEVICE = 3;

    // Numeric values of the cuSPARSE enums used below.
    private static final int CUSPARSE_MATRIX_TYPE_GENERAL = 0;
    private static final int CUSPARSE_INDEX_BASE_ZERO = 0;
    private static final int CUSPARSE_POINTER_MODE_HOST = 0;

    /** Lazily created descriptor shared by every {@code CSRPointer}; see {@link #getDefaultCuSparseMatrixDescriptor()}. */
    public static cusparseMatDescr matrixDescriptor;

    /** GPU context used for every allocation and free performed by this instance. */
    private final GPUContext gpuContext;

    /** Number of non-zero values. */
    public long nnz;
    /** Device pointer to the non-zero values (doubles). */
    public Pointer val;
    /** Device pointer to the row offsets ({@code rows + 1} ints). */
    public Pointer rowPtr;
    /** Device pointer to the column indices (ints). */
    public Pointer colInd;
    /** cuSPARSE matrix descriptor for this matrix. */
    public cusparseMatDescr descr;

    /**
     * Creates a handle whose three pointers are fresh, unallocated {@link Pointer}s and whose
     * descriptor is the shared default descriptor.
     *
     * @param gCtx GPU context that owns the memory of this matrix
     */
    private CSRPointer(GPUContext gCtx) {
        this.gpuContext = gCtx;
        this.val = new Pointer();
        this.rowPtr = new Pointer();
        this.colInd = new Pointer();
        this.allocateMatDescrPointer();
    }

    /** Returns the number of bytes occupied by {@code numElems} doubles. */
    private static long getDoubleSizeOf(long numElems) {
        return numElems * SIZE_OF_DOUBLE;
    }

    /** Returns the number of bytes occupied by {@code numElems} ints. */
    private static long getIntSizeOf(long numElems) {
        return numElems * SIZE_OF_INT;
    }

    /**
     * Narrows a {@code long} to an {@code int}, failing instead of silently truncating.
     *
     * @param l value to narrow
     * @return {@code l} as an int
     * @throws DMLRuntimeException if {@code l} is outside the int range
     */
    public static int toIntExact(long l) throws DMLRuntimeException {
        if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Cannot be cast to int:" + l);
        }
        return (int) l;
    }

    /**
     * Returns the shared matrix descriptor, creating it on first use as a general matrix with
     * zero-based indexing.
     *
     * @return the shared cuSPARSE matrix descriptor
     */
    public static cusparseMatDescr getDefaultCuSparseMatrixDescriptor() {
        if (matrixDescriptor == null) {
            matrixDescriptor = new cusparseMatDescr();
            JCusparse.cusparseCreateMatDescr(matrixDescriptor);
            JCusparse.cusparseSetMatType(matrixDescriptor, CUSPARSE_MATRIX_TYPE_GENERAL);
            JCusparse.cusparseSetMatIndexBase(matrixDescriptor, CUSPARSE_INDEX_BASE_ZERO);
        }
        return matrixDescriptor;
    }

    /**
     * Estimates the device memory needed for a CSR matrix.
     *
     * @param nnz2 number of non-zero values
     * @param rows number of rows
     * @return bytes for values, row pointers ({@code rows + 1} ints), column indices and a
     *         fixed allowance of four ints for the descriptor
     */
    public static long estimateSize(long nnz2, long rows) {
        long sizeofValArray = getDoubleSizeOf(nnz2);
        long sizeofRowPtrArray = getIntSizeOf(rows + 1L);
        long sizeofColIndArray = getIntSizeOf(nnz2);
        long sizeofDescr = getIntSizeOf(4L);
        return sizeofValArray + sizeofRowPtrArray + sizeofColIndArray + sizeofDescr;
    }

    /**
     * Copies host CSR arrays into the already allocated device buffers of {@code dest}.
     *
     * <p>All arguments are validated before {@code dest} is touched, so a rejected call leaves
     * {@code dest.nnz} unchanged.
     *
     * @param dest   destination handle; its pointers must already be allocated
     * @param rows   number of rows; {@code rows + 1} row pointers are copied
     * @param nnz    number of non-zero values to copy
     * @param rowPtr host row offsets, at least {@code rows + 1} long
     * @param colInd host column indices, at least {@code nnz} long
     * @param values host values, at least {@code nnz} long
     * @throws DMLRuntimeException if a count is negative or an array is too short
     */
    public static void copyToDevice(CSRPointer dest, int rows, long nnz, int[] rowPtr, int[] colInd, double[] values) throws DMLRuntimeException {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (rows < 0) {
            throw new DMLRuntimeException("Incorrect input parameter: rows=" + rows);
        }
        if (nnz < 0L) {
            throw new DMLRuntimeException("Incorrect input parameter: nnz=" + nnz);
        }
        // Widen before adding one so that rows == Integer.MAX_VALUE cannot wrap around.
        long numRowPtrs = rows + 1L;
        if ((long) rowPtr.length < numRowPtrs) {
            throw new DMLRuntimeException("The length of rowPtr needs to be greater than or equal to " + numRowPtrs);
        }
        if ((long) colInd.length < nnz) {
            throw new DMLRuntimeException("The length of colInd needs to be greater than or equal to " + nnz);
        }
        if ((long) values.length < nnz) {
            throw new DMLRuntimeException("The length of values needs to be greater than or equal to " + nnz);
        }
        dest.nnz = nnz;
        JCuda.cudaMemcpy(dest.rowPtr, Pointer.to(rowPtr), getIntSizeOf(numRowPtrs), MEMCPY_HOST_TO_DEVICE);
        JCuda.cudaMemcpy(dest.colInd, Pointer.to(colInd), getIntSizeOf(nnz), MEMCPY_HOST_TO_DEVICE);
        JCuda.cudaMemcpy(dest.val, Pointer.to(values), getDoubleSizeOf(nnz), MEMCPY_HOST_TO_DEVICE);
        if (DMLScript.STATISTICS) {
            GPUStatistics.cudaToDevTime.add(System.nanoTime() - t0);
            // Three separate transfers were issued above.
            GPUStatistics.cudaToDevCount.add(3L);
        }
    }

    /**
     * Copies the device buffers of {@code src} into caller-supplied host arrays.
     *
     * <p>No length validation is performed; the caller must provide arrays of at least
     * {@code rows + 1}, {@code nnz} and {@code nnz} elements respectively.
     *
     * @param src    source handle
     * @param rows   number of rows; {@code rows + 1} row pointers are copied
     * @param nnz    number of non-zero values to copy
     * @param rowPtr host destination for the row offsets
     * @param colInd host destination for the column indices
     * @param values host destination for the values
     */
    public static void copyToHost(CSRPointer src, int rows, long nnz, int[] rowPtr, int[] colInd, double[] values) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        // rows + 1L: widen before adding one so the byte count cannot overflow int.
        JCuda.cudaMemcpy(Pointer.to(rowPtr), src.rowPtr, getIntSizeOf(rows + 1L), MEMCPY_DEVICE_TO_HOST);
        JCuda.cudaMemcpy(Pointer.to(colInd), src.colInd, getIntSizeOf(nnz), MEMCPY_DEVICE_TO_HOST);
        JCuda.cudaMemcpy(Pointer.to(values), src.val, getDoubleSizeOf(nnz), MEMCPY_DEVICE_TO_HOST);
        if (DMLScript.STATISTICS) {
            GPUStatistics.cudaFromDevTime.add(System.nanoTime() - t0);
            GPUStatistics.cudaFromDevCount.add(3L);
        }
    }

    /**
     * Rejects operands whose non-zero count does not fit the int-based cuSPARSE interface.
     *
     * @throws DMLRuntimeException if either operand has {@code nnz >= Integer.MAX_VALUE}
     */
    private static void ensureNnzSupported(CSRPointer A, CSRPointer B) throws DMLRuntimeException {
        if (A.nnz >= Integer.MAX_VALUE || B.nnz >= Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Number of non zeroes is larger than supported by cuSparse");
        }
    }

    /**
     * Allocates the output matrix {@code C} for a sparse-sparse addition of {@code A} and {@code B}.
     *
     * @param gCtx   GPU context used for the allocations
     * @param handle cuSPARSE handle
     * @param A      first operand
     * @param B      second operand
     * @param m      number of rows of the result
     * @param n      number of columns of the result
     * @return handle with row pointers filled in and values / column indices allocated
     * @throws DMLRuntimeException if an operand is too large or allocation fails
     */
    public static CSRPointer allocateForDgeam(GPUContext gCtx, cusparseHandle handle, CSRPointer A, CSRPointer B, int m, int n) throws DMLRuntimeException {
        ensureNnzSupported(A, B);
        CSRPointer C = new CSRPointer(gCtx);
        step1AllocateRowPointers(gCtx, handle, C, m);
        step2GatherNNZGeam(gCtx, handle, A, B, C, m, n);
        step3AllocateValNInd(gCtx, handle, C);
        return C;
    }

    /**
     * Allocates the output matrix {@code C} for a sparse-sparse multiplication of {@code A} and {@code B}.
     *
     * @param gCtx   GPU context used for the allocations
     * @param handle cuSPARSE handle
     * @param A      left operand
     * @param transA cuSPARSE operation code applied to {@code A}
     * @param B      right operand
     * @param transB cuSPARSE operation code applied to {@code B}
     * @param m      number of rows of the result
     * @param n      number of columns of the result
     * @param k      inner dimension
     * @return handle with row pointers filled in and values / column indices allocated
     * @throws DMLRuntimeException if an operand is too large or allocation fails
     */
    public static CSRPointer allocateForMatrixMultiply(GPUContext gCtx, cusparseHandle handle, CSRPointer A, int transA, CSRPointer B, int transB, int m, int n, int k) throws DMLRuntimeException {
        // Validate before allocating anything so a rejected call does not leave C.rowPtr behind.
        ensureNnzSupported(A, B);
        CSRPointer C = new CSRPointer(gCtx);
        step1AllocateRowPointers(gCtx, handle, C, m);
        step2GatherNNZGemm(gCtx, handle, A, transA, B, transB, C, m, n, k);
        step3AllocateValNInd(gCtx, handle, C);
        return C;
    }

    /**
     * Allocates device buffers for a matrix with the given dimensions without initialising them.
     *
     * <p>When {@code nnz2} is zero nothing is allocated and the returned handle keeps the
     * unallocated pointers set by the constructor.
     *
     * @param gCtx GPU context used for the allocations
     * @param nnz2 number of non-zero values; must not be negative
     * @param rows number of rows
     * @return the new handle
     * @throws DMLRuntimeException if allocation fails
     */
    public static CSRPointer allocateEmpty(GPUContext gCtx, long nnz2, long rows) throws DMLRuntimeException {
        LOG.trace("GPU : allocateEmpty from CSRPointer with nnz=" + nnz2 + " and rows=" + rows + ", GPUContext=" + gCtx);
        assert (nnz2 > -1L) : "Incorrect usage of internal API, number of non zeroes is less than 0 when trying to allocate sparse data on GPU";
        CSRPointer r = new CSRPointer(gCtx);
        r.nnz = nnz2;
        if (nnz2 == 0L) {
            return r;
        }
        long valBytes = getDoubleSizeOf(nnz2);
        long rowPtrBytes = getIntSizeOf(rows + 1L);
        long colIndBytes = getIntSizeOf(nnz2);
        gCtx.ensureFreeSpace(valBytes + rowPtrBytes + colIndBytes);
        r.val = gCtx.allocate(null, valBytes);
        r.rowPtr = gCtx.allocate(null, rowPtrBytes);
        r.colInd = gCtx.allocate(null, colIndBytes);
        return r;
    }

    /**
     * Step 1 of output allocation: sets the handle's pointer mode and allocates
     * {@code rowsC + 1} row pointers for {@code C}.
     */
    private static void step1AllocateRowPointers(GPUContext gCtx, cusparseHandle handle, CSRPointer C, int rowsC) throws DMLRuntimeException {
        LOG.trace("GPU : step1AllocateRowPointers, GPUContext=" + gCtx);
        // Host pointer mode: step 2 hands cuSPARSE a host array for the total nnz.
        JCusparse.cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST);
        C.rowPtr = gCtx.allocate(getIntSizeOf((long) rowsC + 1L));
    }

    /**
     * Determines the non-zero count of {@code C} after a cuSPARSE {@code *Nnz} call.
     *
     * <p>If the call wrote a total into {@code nnzTotalHost} (i.e. it is no longer the -1
     * sentinel) that value is used. Otherwise the count is derived from the row pointers on the
     * device as {@code rowPtr[m] - rowPtr[0]}.
     *
     * @param C            output matrix whose row pointers have been filled in
     * @param m            number of rows of {@code C}
     * @param nnzTotalHost one-element host array passed to the {@code *Nnz} call, pre-set to -1
     * @return number of non-zero values of {@code C}
     */
    private static long resolveNnz(CSRPointer C, int m, int[] nnzTotalHost) {
        if (nnzTotalHost[0] != -1) {
            return nnzTotalHost[0];
        }
        int[] baseArray = new int[]{0};
        JCuda.cudaMemcpy(Pointer.to(nnzTotalHost), C.rowPtr.withByteOffset(getIntSizeOf(m)), getIntSizeOf(1L), MEMCPY_DEVICE_TO_HOST);
        JCuda.cudaMemcpy(Pointer.to(baseArray), C.rowPtr, getIntSizeOf(1L), MEMCPY_DEVICE_TO_HOST);
        return nnzTotalHost[0] - baseArray[0];
    }

    /**
     * Step 2 for addition: lets cuSPARSE fill {@code C.rowPtr} and stores the resulting
     * non-zero count in {@code C.nnz}.
     */
    private static void step2GatherNNZGeam(GPUContext gCtx, cusparseHandle handle, CSRPointer A, CSRPointer B, CSRPointer C, int m, int n) throws DMLRuntimeException {
        LOG.trace("GPU : step2GatherNNZGeam for DGEAM, GPUContext=" + gCtx);
        int[] CnnzArray = new int[]{-1};
        JCusparse.cusparseXcsrgeamNnz(handle, m, n,
                A.descr, toIntExact(A.nnz), A.rowPtr, A.colInd,
                B.descr, toIntExact(B.nnz), B.rowPtr, B.colInd,
                C.descr, C.rowPtr, Pointer.to(CnnzArray));
        C.nnz = resolveNnz(C, m, CnnzArray);
    }

    /**
     * Step 2 for multiplication: lets cuSPARSE fill {@code C.rowPtr} and stores the resulting
     * non-zero count in {@code C.nnz}. The caller is expected to have checked the operand
     * sizes; {@link #toIntExact(long)} still guards the narrowing conversions.
     */
    private static void step2GatherNNZGemm(GPUContext gCtx, cusparseHandle handle, CSRPointer A, int transA, CSRPointer B, int transB, CSRPointer C, int m, int n, int k) throws DMLRuntimeException {
        LOG.trace("GPU : step2GatherNNZGemm for DGEMM, GPUContext=" + gCtx);
        int[] CnnzArray = new int[]{-1};
        JCusparse.cusparseXcsrgemmNnz(handle, transA, transB, m, n, k,
                A.descr, toIntExact(A.nnz), A.rowPtr, A.colInd,
                B.descr, toIntExact(B.nnz), B.rowPtr, B.colInd,
                C.descr, C.rowPtr, Pointer.to(CnnzArray));
        C.nnz = resolveNnz(C, m, CnnzArray);
    }

    /** Step 3 of output allocation: allocates values and column indices for {@code C.nnz} entries. */
    private static void step3AllocateValNInd(GPUContext gCtx, cusparseHandle handle, CSRPointer C) throws DMLRuntimeException {
        LOG.trace("GPU : step3AllocateValNInd, GPUContext=" + gCtx);
        C.val = gCtx.allocate(null, getDoubleSizeOf(C.nnz));
        C.colInd = gCtx.allocate(null, getIntSizeOf(C.nnz));
    }

    /**
     * Creates a deep copy of this matrix on the same GPU context.
     *
     * <p>Buffer sizes follow the CSR layout used everywhere else in this class: {@code nnz}
     * doubles for the values, {@code rows + 1} ints for the row pointers and {@code nnz} ints
     * for the column indices. Copying more than that would read past the end of the source
     * buffers.
     *
     * @param rows number of rows of this matrix
     * @return a new handle owning freshly allocated copies of the three buffers
     * @throws DMLRuntimeException if allocation fails
     */
    public CSRPointer clone(int rows) throws DMLRuntimeException {
        CSRPointer that = new CSRPointer(this.getGPUContext());
        that.gpuContext.ensureFreeSpace(estimateSize(this.nnz, rows));
        that.nnz = this.nnz;

        long valBytes = getDoubleSizeOf(that.nnz);
        // rows + 1L: one extra row pointer, computed in long to avoid int overflow.
        long rowPtrBytes = getIntSizeOf(rows + 1L);
        long colIndBytes = getIntSizeOf(that.nnz);

        that.val = this.allocate(valBytes);
        that.rowPtr = this.allocate(rowPtrBytes);
        that.colInd = this.allocate(colIndBytes);

        JCuda.cudaMemcpy(that.val, this.val, valBytes, MEMCPY_DEVICE_TO_DEVICE);
        JCuda.cudaMemcpy(that.rowPtr, this.rowPtr, rowPtrBytes, MEMCPY_DEVICE_TO_DEVICE);
        JCuda.cudaMemcpy(that.colInd, this.colInd, colIndBytes, MEMCPY_DEVICE_TO_DEVICE);
        return that;
    }

    /** Allocates {@code size} bytes through this instance's GPU context. */
    private Pointer allocate(long size) throws DMLRuntimeException {
        return this.getGPUContext().allocate(size);
    }

    /** Frees {@code toFree} through this instance's GPU context. */
    private void cudaFreeHelper(Pointer toFree, boolean eager) throws DMLRuntimeException {
        this.getGPUContext().cudaFreeHelper(toFree, eager);
    }

    /** Returns the GPU context this matrix belongs to. */
    private GPUContext getGPUContext() {
        return this.gpuContext;
    }

    /**
     * Tells whether the sparsity {@code nnz / (rows * cols)} is below the ultra-sparsity threshold.
     *
     * @param rows number of rows of this matrix
     * @param cols number of columns of this matrix
     * @return {@code true} if the matrix is ultra sparse
     */
    public boolean isUltraSparse(int rows, int cols) {
        double sp = (double) this.nnz / (double) rows / (double) cols;
        return sp < ULTRA_SPARSITY_TURN_POINT;
    }

    /** Points {@link #descr} at the shared default descriptor. */
    private void allocateMatDescrPointer() {
        this.descr = getDefaultCuSparseMatrixDescriptor();
    }

    /**
     * Converts this matrix to a dense matrix in column-major layout on the device.
     *
     * <p>The dense buffer is always allocated. The conversion itself only runs when all three
     * CSR pointers are non-null and {@code nnz > 0}; otherwise the buffer is returned as
     * allocated and a debug message is logged.
     *
     * @param cusparseHandle2 cuSPARSE handle used for the conversion
     * @param cublasHandle2   cuBLAS handle (not used by this method)
     * @param rows            number of rows
     * @param cols            number of columns
     * @return device pointer to {@code rows * cols} doubles
     * @throws DMLRuntimeException if allocation fails
     */
    public Pointer toColumnMajorDenseMatrix(cusparseHandle cusparseHandle2, cublasHandle cublasHandle2, int rows, int cols) throws DMLRuntimeException {
        LOG.trace("GPU : sparse -> column major dense (inside CSRPointer) on " + this + ", GPUContext=" + this.getGPUContext());
        long size = (long) rows * getDoubleSizeOf(cols);
        Pointer A = this.allocate(size);
        boolean hasData = this.val != null && this.rowPtr != null && this.colInd != null && this.nnz > 0L;
        if (hasData) {
            // Leading dimension equals rows because the output is column major.
            JCusparse.cusparseDcsr2dense(cusparseHandle2, rows, cols, this.descr, this.val, this.rowPtr, this.colInd, A, rows);
        } else {
            LOG.debug("in CSRPointer, the values array, row pointers array or column indices array was null");
        }
        return A;
    }

    /**
     * Frees the device buffers lazily; equivalent to {@code deallocate(false)}.
     *
     * @throws DMLRuntimeException if freeing fails
     */
    public void deallocate() throws DMLRuntimeException {
        this.deallocate(false);
    }

    /**
     * Frees the three device buffers and nulls the pointers. Nothing is done when
     * {@code nnz} is zero.
     *
     * @param eager passed through to the GPU context's free helper
     * @throws DMLRuntimeException if freeing fails
     */
    public void deallocate(boolean eager) throws DMLRuntimeException {
        // NOTE(review): a result of allocateForDgeam/allocateForMatrixMultiply that ends up with
        // nnz == 0 still owns the rowPtr allocated in step 1, which this guard skips - confirm
        // whether that buffer is released elsewhere.
        if (this.nnz > 0L) {
            this.cudaFreeHelper(this.val, eager);
            this.cudaFreeHelper(this.rowPtr, eager);
            this.cudaFreeHelper(this.colInd, eager);
            this.val = null;
            this.rowPtr = null;
            this.colInd = null;
        }
    }

    @Override
    public String toString() {
        return "CSRPointer{nnz=" + this.nnz + '}';
    }
}

