/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnActivationDescriptor;
import jcuda.jcudnn.cudnnConvolutionDescriptor;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnPoolingDescriptor;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNNConvolutionAlgorithm;
import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNNInputRowFetcher;
import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNNPoolingDescriptors;
import org.apache.sysml.utils.GPUStatistics;

/**
 * GPU implementations of convolution, pooling, ReLU and softmax that delegate to cuDNN
 * through the JCudnn bindings.
 *
 * <p>All operators take their inputs as {@link MatrixObject}s. The convolution and pooling
 * operators treat a matrix as {@code N} rows of a flattened tensor (e.g. {@code C*H*W}
 * columns). When converting sparse inputs to dense form would exceed the caller-provided
 * {@code intermediateMemoryBudget}, these operators fall back to processing one input row at a
 * time via {@link LibMatrixCuDNNInputRowFetcher}.
 *
 * <p>Parameter naming used throughout (presumably the usual NCHW convention &mdash; batch size
 * {@code N}, channels {@code C}, input height/width {@code H}/{@code W}, number of filters
 * {@code K}, filter height/width {@code R}/{@code S}, output height/width {@code P}/{@code Q};
 * confirm against the callers).
 */
public class LibMatrixCuDNN extends LibMatrixCUDA {
    protected static int CONVOLUTION_PREFERENCE = 0;
    private static final Log LOG = LogFactory.getLog(LibMatrixCuDNN.class.getName());

    /**
     * Returns the cuDNN handle associated with the given GPU context.
     *
     * @param gCtx GPU context to query
     * @return the context's cuDNN handle
     * @throws DMLRuntimeException declared for callers; this method only delegates to the context
     */
    protected static cudnnHandle getCudnnHandle(GPUContext gCtx) throws DMLRuntimeException {
        return gCtx.getCudnnHandle();
    }

    /**
     * Runs {@link #conv2d} into {@code output} and then adds {@code bias} to it in place.
     *
     * @param gCtx GPU context
     * @param instName instruction name used for statistics
     * @param image input matrix (N rows, C*H*W columns)
     * @param bias bias matrix handed to {@code biasAdd}
     * @param filter filter matrix (K rows, C*R*S columns)
     * @param output output matrix (N rows, K*P*Q columns); used as both input and output of the bias add
     * @param N number of input rows
     * @param C see class comment
     * @param H see class comment
     * @param W see class comment
     * @param K see class comment
     * @param R see class comment
     * @param S see class comment
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P see class comment
     * @param Q see class comment
     * @param intermediateMemoryBudget memory available for sparse-to-dense conversion and workspace
     * @throws DMLRuntimeException if either step fails
     */
    public static void conv2dBiasAdd(GPUContext gCtx, String instName, MatrixObject image, MatrixObject bias, MatrixObject filter, MatrixObject output, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        conv2d(gCtx, instName, image, filter, output, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, intermediateMemoryBudget);
        biasAdd(gCtx, instName, output, bias, output);
    }

    /**
     * Forward 2D convolution.
     *
     * <p>If the estimated dense-conversion overhead of sparse inputs fits in
     * {@code intermediateMemoryBudget}, the whole batch is convolved in one cuDNN call; otherwise
     * the image is processed one row at a time, writing each result at the matching row offset of
     * the output.
     *
     * @param gCtx GPU context
     * @param instName instruction name used for statistics
     * @param image input matrix (N rows, C*H*W columns)
     * @param filter filter matrix (K rows, C*R*S columns)
     * @param outputBlock output matrix (N rows, K*P*Q columns)
     * @param N number of input rows
     * @param C see class comment
     * @param H see class comment
     * @param W see class comment
     * @param K see class comment
     * @param R see class comment
     * @param S see class comment
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P see class comment
     * @param Q see class comment
     * @param intermediateMemoryBudget memory available for sparse-to-dense conversion and workspace
     * @throws DMLRuntimeException if a tensor is too large for cuDNN or the cuDNN call fails
     */
    public static void conv2d(GPUContext gCtx, String instName, MatrixObject image, MatrixObject filter, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Widen to long BEFORE multiplying; an int product could overflow and slip past the size guard.
        long CHW = (long) C * H * W;
        long KPQ = (long) K * P * Q;
        long CRS = (long) C * R * S;
        long NCHW = (long) N * CHW;
        long NKPQ = (long) N * KPQ;
        long KCRS = (long) K * CRS;
        if (NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
            // Sparsity is sampled before the filter is converted to dense below.
            double filterOverhead = isInSparseFormat(gCtx, filter) ? (double) OptimizerUtils.estimateSizeExactSparsity((long) K, CRS, 1.0) : 0.0;
            double imageOverhead = isInSparseFormat(gCtx, image) ? (double) OptimizerUtils.estimateSizeExactSparsity((long) N, CHW, 1.0) : 0.0;
            Pointer filterPointer = getDensePointerForCuDNN(gCtx, filter, instName);
            Pointer dstPointer = getDensePointerForCuDNN(gCtx, outputBlock, instName);
            double overhead = filterOverhead + imageOverhead;
            long workspaceLimit = (long) (intermediateMemoryBudget - overhead);
            // Batch size the cuDNN descriptors are built for: the full batch, or one row at a time.
            int localN = overhead <= intermediateMemoryBudget ? N : 1;
            try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionForwardAlgorithm(gCtx, instName, localN, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
                if (localN == N) {
                    // All rows in a single call.
                    Pointer imagePointer = getDensePointerForCuDNN(gCtx, image, instName);
                    cudnnConv2d(gCtx, instName, imagePointer, filterPointer, dstPointer, algo);
                } else {
                    // One row per call; the descriptors in algo were built for a batch of one.
                    try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image)) {
                        for (int n = 0; n < N; ++n) {
                            long outOffsetBytes = (long) n * KPQ * (long) sizeOfDataType;
                            cudnnConv2d(gCtx, instName, imgFetcher.getNthRow(n), filterPointer, dstPointer.withByteOffset(outOffsetBytes), algo);
                        }
                    }
                }
            }
        } else {
            throwCuDNNDimensionError(N, CHW, K, CRS, N, KPQ);
        }
    }

    /**
     * Computes a softmax of {@code in1} into the matrix named {@code outputName}.
     *
     * <p>The input is described to cuDNN as a 4D tensor of shape
     * {@code [rows, columns, 1, 1]}.
     *
     * @param ec execution context used to look up and allocate the output
     * @param gCtx GPU context
     * @param instName instruction name used for statistics
     * @param in1 input matrix
     * @param outputName name of the output variable
     * @throws DMLRuntimeException if a dimension does not fit in an int or cuDNN reports an error
     */
    public static void softmax(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String outputName) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : softmax, GPUContext=" + gCtx);
        }
        cudnnTensorDescriptor tensorDesc = allocateTensorDescriptor(toInt(in1.getNumRows()), toInt(in1.getNumColumns()), 1, 1);
        try {
            Pointer srcPointer = getDensePointerForCuDNN(gCtx, in1, instName);
            MatrixObject out = ec.getMatrixObject(outputName);
            ec.allocateGPUMatrixObject(outputName, in1.getNumRows(), in1.getNumColumns());
            out.getGPUObject(gCtx).allocateAndFillDense(0.0);
            Pointer dstPointer = getDensePointerForCuDNN(gCtx, out, instName);
            // Algorithm 1 / mode 1: CUDNN_SOFTMAX_ACCURATE / CUDNN_SOFTMAX_MODE_CHANNEL per the cuDNN enum values.
            checkStatus(JCudnn.cudnnSoftmaxForward(gCtx.getCudnnHandle(), 1, 1, one(), tensorDesc, srcPointer, zero(), tensorDesc, dstPointer));
        } finally {
            // Release the native descriptor even when an intermediate step throws.
            JCudnn.cudnnDestroyTensorDescriptor(tensorDesc);
        }
    }

    /**
     * Creates a 4D tensor descriptor with the given dimensions, using format {@code 0}
     * (CUDNN_TENSOR_NCHW per the cuDNN enum values) and {@code CUDNN_DATA_TYPE}.
     * The caller is responsible for destroying the returned descriptor.
     */
    private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
        cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(tensorDescriptor);
        JCudnn.cudnnSetTensor4dDescriptor(tensorDescriptor, 0, LibMatrixCUDA.CUDNN_DATA_TYPE, N, C, H, W);
        return tensorDescriptor;
    }

    /**
     * Always throws, reporting one input matrix {@code [dim1, dim2]} and one output matrix
     * {@code [dim3, dim4]} as too large for cuDNN.
     */
    private static void throwCuDNNDimensionError(long dim1, long dim2, long dim3, long dim4) throws DMLRuntimeException {
        throw new DMLRuntimeException("The dimensions of input/output matrices is too large to execute a CuDNN kernel. Max CuDNN matrix size:" + maxNumElementsOfCuDNNTensor + ". Given input matrix dimensions: [" + dim1 + "," + dim2 + "]. Output dimension:  [" + dim3 + "," + dim4 + "].");
    }

    /**
     * Always throws, reporting two input matrices {@code [dim1, dim2]}, {@code [dim3, dim4]} and
     * one output matrix {@code [dim5, dim6]} as too large for cuDNN.
     */
    private static void throwCuDNNDimensionError(long dim1, long dim2, long dim3, long dim4, long dim5, long dim6) throws DMLRuntimeException {
        throw new DMLRuntimeException("The dimensions of input/output matrices is too large to execute a CuDNN kernel. Max CuDNN matrix size:" + maxNumElementsOfCuDNNTensor + ". Given input matrix dimensions: [" + dim1 + "," + dim2 + "], [" + dim3 + "," + dim4 + "]. Output dimension: [" + dim5 + "," + dim6 + "]");
    }

    /**
     * Issues a single {@code cudnnConvolutionForward} call using the descriptors, algorithm and
     * workspace held by {@code algo}, and records its duration under the "nncf" statistic.
     */
    private static void cudnnConv2d(GPUContext gCtx, String instName, Pointer image, Pointer filter, Pointer output, LibMatrixCuDNNConvolutionAlgorithm algo) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2d, GPUContext=" + gCtx);
        }
        try {
            long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            int status = JCudnn.cudnnConvolutionForward(getCudnnHandle(gCtx), one(), algo.nchwTensorDesc, image, algo.filterDesc, filter, algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes, zero(), algo.nkpqTensorDesc, output);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nncf", System.nanoTime() - t1);
            }
            if (status != 0) {
                throw new DMLRuntimeException("Could not executed cudnnConvolutionForward: " + cudnnStatus.stringFor(status));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Gradient of a 2D convolution with respect to the filter.
     *
     * <p>In the full-batch case a single cuDNN call writes the result. In the row-by-row case
     * each row's gradient is computed into a zeroed temporary buffer and accumulated into the
     * output with the {@code inplace_add} kernel.
     *
     * @param gCtx GPU context
     * @param instName instruction name used for statistics
     * @param image input matrix (N rows, C*H*W columns)
     * @param dout upstream gradient (N rows, K*P*Q columns)
     * @param outputBlock filter gradient (K rows, C*R*S columns)
     * @param N number of input rows
     * @param C see class comment
     * @param H see class comment
     * @param W see class comment
     * @param K see class comment
     * @param R see class comment
     * @param S see class comment
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P see class comment
     * @param Q see class comment
     * @param intermediateMemoryBudget memory available for sparse-to-dense conversion and workspace
     * @throws DMLRuntimeException if a tensor is too large for cuDNN or the cuDNN call fails
     */
    public static void conv2dBackwardFilter(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Widen to long BEFORE multiplying; an int product could overflow and slip past the size guard.
        long CHW = (long) C * H * W;
        long KPQ = (long) K * P * Q;
        long CRS = (long) C * R * S;
        long NCHW = (long) N * CHW;
        long NKPQ = (long) N * KPQ;
        long KCRS = (long) K * CRS;
        if (NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
            Pointer dwPointer = getDensePointerForCuDNN(gCtx, outputBlock, instName);
            double overhead = isInSparseFormat(gCtx, image) ? (double) OptimizerUtils.estimateSizeExactSparsity((long) N, CHW, 1.0) : 0.0;
            overhead += isInSparseFormat(gCtx, dout) ? (double) OptimizerUtils.estimateSizeExactSparsity((long) N, KPQ, 1.0) : 0.0;
            long workspaceLimit = (long) (intermediateMemoryBudget - overhead);
            // Batch size the cuDNN descriptors are built for: the full batch, or one row at a time.
            int localN = overhead <= intermediateMemoryBudget ? N : 1;
            try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardFilterAlgorithm(gCtx, instName, localN, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
                if (localN == N) {
                    // All rows in a single call.
                    Pointer imagePointer = getDensePointerForCuDNN(gCtx, image, instName);
                    Pointer doutPointer = getDensePointerForCuDNN(gCtx, dout, instName);
                    cudnnConv2dBackwardFilter(gCtx, instName, imagePointer, doutPointer, dwPointer, algo);
                } else {
                    // One row per call, accumulating each partial gradient into dwPointer.
                    try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
                         LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
                        long dwBytes = KCRS * (long) sizeOfDataType;
                        int crs = toInt(CRS);
                        Pointer tempdwPointer = gCtx.allocate(dwBytes);
                        try {
                            for (int n = 0; n < N; ++n) {
                                long t0 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
                                JCuda.cudaMemset(tempdwPointer, 0, dwBytes);
                                if (DMLScript.FINEGRAINED_STATISTICS) {
                                    GPUStatistics.maintainCPMiscTimes(instName, "az", System.nanoTime() - t0);
                                }
                                cudnnConv2dBackwardFilter(gCtx, instName, imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), tempdwPointer, algo);
                                getCudaKernels(gCtx).launchKernel("inplace_add", ExecutionConfig.getConfigForSimpleMatrixOperations(K, crs), tempdwPointer, dwPointer, K, crs);
                            }
                        } finally {
                            // Free the scratch buffer even if an iteration throws.
                            gCtx.cudaFreeHelper(tempdwPointer, true);
                        }
                    }
                }
            }
        } else {
            throwCuDNNDimensionError(N, CHW, N, KPQ, K, CRS);
        }
    }

    /**
     * Issues a single {@code cudnnConvolutionBackwardFilter} call using the descriptors,
     * algorithm and workspace held by {@code algo}, and records its duration under "nncbf".
     */
    private static void cudnnConv2dBackwardFilter(GPUContext gCtx, String instName, Pointer imagePointer, Pointer doutPointer, Pointer dwPointer, LibMatrixCuDNNConvolutionAlgorithm algo) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2dBackwardFilter, GPUContext=" + gCtx);
        }
        try {
            long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            int status = JCudnn.cudnnConvolutionBackwardFilter(getCudnnHandle(gCtx), one(), algo.nchwTensorDesc, imagePointer, algo.nkpqTensorDesc, doutPointer, algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes, zero(), algo.filterDesc, dwPointer);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nncbf", System.nanoTime() - t1);
            }
            if (status != 0) {
                throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardFilter: " + cudnnStatus.stringFor(status));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Gradient of a 2D convolution with respect to its input.
     *
     * <p>Processes the whole batch in one cuDNN call when the sparse-to-dense overhead fits in
     * {@code intermediateMemoryBudget}; otherwise processes {@code dout} one row at a time and
     * writes each result at the matching row offset of the output.
     *
     * @param gCtx GPU context
     * @param instName instruction name used for statistics
     * @param filter filter matrix (K rows, C*R*S columns)
     * @param dout upstream gradient (N rows, K*P*Q columns)
     * @param output input gradient (N rows, C*H*W columns)
     * @param N number of input rows
     * @param C see class comment
     * @param H see class comment
     * @param W see class comment
     * @param K see class comment
     * @param R see class comment
     * @param S see class comment
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P see class comment
     * @param Q see class comment
     * @param intermediateMemoryBudget memory available for sparse-to-dense conversion and workspace
     * @throws DMLRuntimeException if a tensor is too large for cuDNN or the cuDNN call fails
     */
    public static void conv2dBackwardData(GPUContext gCtx, String instName, MatrixObject filter, MatrixObject dout, MatrixObject output, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Widen to long BEFORE multiplying; an int product could overflow and slip past the size guard.
        long CHW = (long) C * H * W;
        long KPQ = (long) K * P * Q;
        long CRS = (long) C * R * S;
        long NCHW = (long) N * CHW;
        long NKPQ = (long) N * KPQ;
        long KCRS = (long) K * CRS;
        if (NCHW < maxNumElementsOfCuDNNTensor && NKPQ < maxNumElementsOfCuDNNTensor && KCRS < maxNumElementsOfCuDNNTensor) {
            // Sparsity is sampled before the filter is converted to dense below.
            double filterOverhead = isInSparseFormat(gCtx, filter) ? (double) OptimizerUtils.estimateSizeExactSparsity((long) K, CRS, 1.0) : 0.0;
            double doutOverhead = isInSparseFormat(gCtx, dout) ? (double) OptimizerUtils.estimateSizeExactSparsity((long) N, KPQ, 1.0) : 0.0;
            Pointer filterPointer = getDensePointerForCuDNN(gCtx, filter, instName);
            Pointer dstPointer = getDensePointerForCuDNN(gCtx, output, instName);
            double overhead = filterOverhead + doutOverhead;
            long workspaceLimit = (long) (intermediateMemoryBudget - overhead);
            // Batch size the cuDNN descriptors are built for: the full batch, or one row at a time.
            int localN = overhead <= intermediateMemoryBudget ? N : 1;
            try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardDataAlgorithm(gCtx, instName, localN, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, workspaceLimit)) {
                if (localN == N) {
                    // All rows in a single call.
                    Pointer doutPointer = getDensePointerForCuDNN(gCtx, dout, instName);
                    cudnnConv2dBackwardData(gCtx, instName, filterPointer, doutPointer, dstPointer, algo);
                } else {
                    // One row per call. Argument order must match the helper: (w, dy, dx).
                    try (LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout)) {
                        for (int n = 0; n < N; ++n) {
                            long outOffsetBytes = (long) n * CHW * (long) sizeOfDataType;
                            cudnnConv2dBackwardData(gCtx, instName, filterPointer, doutFetcher.getNthRow(n), dstPointer.withByteOffset(outOffsetBytes), algo);
                        }
                    }
                }
            }
        } else {
            throwCuDNNDimensionError(N, CHW, N, KPQ, K, CRS);
        }
    }

    /**
     * Issues a single {@code cudnnConvolutionBackwardData} call and records its duration under
     * "nncbd".
     *
     * @param w filter data, described by {@code algo.filterDesc}
     * @param dy upstream gradient, described by {@code algo.nkpqTensorDesc}
     * @param dx destination for the input gradient, described by {@code algo.nchwTensorDesc}
     */
    private static void cudnnConv2dBackwardData(GPUContext gCtx, String instName, Pointer w, Pointer dy, Pointer dx, LibMatrixCuDNNConvolutionAlgorithm algo) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2dBackwardData, GPUContext=" + gCtx);
        }
        try {
            long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            int status = JCudnn.cudnnConvolutionBackwardData(getCudnnHandle(gCtx), one(), algo.filterDesc, w, algo.nkpqTensorDesc, dy, algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes, zero(), algo.nchwTensorDesc, dx);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nncbd", System.nanoTime() - t1);
            }
            if (status != 0) {
                throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardData: " + cudnnStatus.stringFor(status));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Forward max pooling.
     *
     * <p>Processes the whole batch in one cuDNN call when the sparse-to-dense overhead of
     * {@code image} fits in {@code intermediateMemoryBudget}; otherwise processes one row at a
     * time, writing each result at the matching row offset of the output.
     *
     * @param gCtx GPU context
     * @param instName instruction name used for statistics
     * @param image input matrix (N rows, C*H*W columns)
     * @param outputBlock output matrix (N rows, C*P*Q columns)
     * @param N number of input rows
     * @param C see class comment
     * @param H see class comment
     * @param W see class comment
     * @param K forwarded unchanged to the pooling descriptor factory
     * @param R see class comment
     * @param S see class comment
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P see class comment
     * @param Q see class comment
     * @param intermediateMemoryBudget memory available for sparse-to-dense conversion
     * @throws DMLRuntimeException if a tensor is too large for cuDNN or the cuDNN call fails
     */
    public static void maxpooling(GPUContext gCtx, String instName, MatrixObject image, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Widen to long BEFORE multiplying; an int product could overflow and slip past the size guard.
        long CHW = (long) C * H * W;
        long CPQ = (long) C * P * Q;
        long NCHW = (long) N * CHW;
        long NCPQ = (long) N * CPQ;
        if (NCHW < maxNumElementsOfCuDNNTensor && NCPQ < maxNumElementsOfCuDNNTensor) {
            long overhead = isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity((long) N, CHW, 1.0) : 0L;
            Pointer y = getDensePointerForCuDNN(gCtx, outputBlock, instName);
            if ((double) overhead <= intermediateMemoryBudget) {
                Pointer x = getDensePointerForCuDNN(gCtx, image, instName);
                cudnnMaxpooling(gCtx, instName, x, y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
            } else {
                // try-with-resources so the fetcher is released even if a row fails.
                try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image)) {
                    for (int n = 0; n < N; ++n) {
                        long outOffsetBytes = (long) n * CPQ * (long) sizeOfDataType;
                        cudnnMaxpooling(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset(outOffsetBytes), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
                    }
                }
            }
        } else {
            throwCuDNNDimensionError(N, CHW, N, CPQ);
        }
    }

    /**
     * Builds pooling descriptors for the given geometry and issues a single
     * {@code cudnnPoolingForward} call from {@code x} into {@code y}, recording its duration
     * under "nnmf".
     */
    private static void cudnnMaxpooling(GPUContext gCtx, String instName, Pointer x, Pointer y, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : performMaxpooling, GPUContext=" + gCtx);
        }
        try (LibMatrixCuDNNPoolingDescriptors desc = LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingDescriptors(gCtx, instName, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q)) {
            if (DMLScript.FINEGRAINED_STATISTICS) {
                // No work remains between these two timestamps; the entry is kept so the
                // "nni" statistic is still recorded for this instruction.
                long t1 = System.nanoTime();
                GPUStatistics.maintainCPMiscTimes(instName, "nni", System.nanoTime() - t1);
            }
            long t2 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            int status = JCudnn.cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nnmf", System.nanoTime() - t2);
            }
            if (status != 0) {
                throw new DMLRuntimeException("Could not executed cudnnPoolingForward: " + cudnnStatus.stringFor(status));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Backward pass of max pooling.
     *
     * <p>If {@code maxpoolOutput} is {@code null}, the forward pooling result is recomputed
     * internally before the backward call. Processes the whole batch in one call when the
     * sparse-to-dense overhead of {@code image} and {@code dout} fits in
     * {@code intermediateMemoryBudget}; otherwise processes one row at a time.
     *
     * @param gCtx GPU context
     * @param instName instruction name used for statistics
     * @param image input matrix of the forward pass (N rows, C*H*W columns)
     * @param dout upstream gradient (N rows, C*P*Q columns)
     * @param maxpoolOutput forward pooling output, or {@code null} to recompute it
     * @param outputBlock input gradient (N rows, C*H*W columns)
     * @param N number of input rows
     * @param C see class comment
     * @param H see class comment
     * @param W see class comment
     * @param K forwarded unchanged to the pooling descriptor factory
     * @param R see class comment
     * @param S see class comment
     * @param pad_h vertical padding
     * @param pad_w horizontal padding
     * @param stride_h vertical stride
     * @param stride_w horizontal stride
     * @param P see class comment
     * @param Q see class comment
     * @param intermediateMemoryBudget memory available for sparse-to-dense conversion
     * @throws DMLRuntimeException if a tensor is too large for cuDNN or a cuDNN call fails
     */
    public static void maxpoolingBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout, MatrixObject maxpoolOutput, MatrixObject outputBlock, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, double intermediateMemoryBudget) throws DMLRuntimeException {
        // Widen to long BEFORE multiplying; an int product could overflow and slip past the size guard.
        long CHW = (long) C * H * W;
        long CPQ = (long) C * P * Q;
        long NCHW = (long) N * CHW;
        long NCPQ = (long) N * CPQ;
        boolean isMaxPoolOutputProvided = maxpoolOutput != null;
        if (NCHW < maxNumElementsOfCuDNNTensor && NCPQ < maxNumElementsOfCuDNNTensor) {
            long overhead = isInSparseFormat(gCtx, image) ? OptimizerUtils.estimateSizeExactSparsity((long) N, CHW, 1.0) : 0L;
            overhead += isInSparseFormat(gCtx, dout) ? OptimizerUtils.estimateSizeExactSparsity((long) N, CPQ, 1.0) : 0L;
            Pointer dx = getDensePointerForCuDNN(gCtx, outputBlock, instName);
            if ((double) overhead <= intermediateMemoryBudget) {
                Pointer x = getDensePointerForCuDNN(gCtx, image, instName);
                Pointer dy = getDensePointerForCuDNN(gCtx, dout, instName);
                Pointer y = isMaxPoolOutputProvided ? getDensePointerForCuDNN(gCtx, maxpoolOutput, instName) : null;
                cudnnMaxpoolingBackward(gCtx, instName, x, dy, y, dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
            } else {
                // A null resource is permitted in try-with-resources and is simply not closed.
                try (LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
                     LibMatrixCuDNNInputRowFetcher doutFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout);
                     LibMatrixCuDNNInputRowFetcher maxPoolOutFetcher = isMaxPoolOutputProvided ? new LibMatrixCuDNNInputRowFetcher(gCtx, instName, maxpoolOutput) : null) {
                    for (int n = 0; n < N; ++n) {
                        Pointer x = imgFetcher.getNthRow(n);
                        Pointer dy = doutFetcher.getNthRow(n);
                        Pointer y = isMaxPoolOutputProvided ? maxPoolOutFetcher.getNthRow(n) : null;
                        long outOffsetBytes = (long) n * CHW * (long) sizeOfDataType;
                        cudnnMaxpoolingBackward(gCtx, instName, x, dy, y, dx.withByteOffset(outOffsetBytes), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
                    }
                }
            }
        } else {
            throwCuDNNDimensionError(N, CHW, N, CPQ);
        }
    }

    /**
     * Issues a {@code cudnnPoolingBackward} call. When {@code y} is {@code null}, a scratch
     * buffer of {@code N*C*P*Q} elements is allocated, filled by a {@code cudnnPoolingForward}
     * call, used as the forward output, and freed before returning.
     */
    private static void cudnnMaxpoolingBackward(GPUContext gCtx, String instName, Pointer x, Pointer dy, Pointer y, Pointer dx, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : maxpoolingBackward, GPUContext=" + gCtx);
        }
        boolean isMaxPoolOutputProvided = y != null;
        // Non-null only if this method allocated the forward-output buffer itself.
        Pointer scratchY = null;
        try (LibMatrixCuDNNPoolingDescriptors desc = LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingBackwardDescriptors(gCtx, instName, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q)) {
            int status;
            if (!isMaxPoolOutputProvided) {
                long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
                // Widen before multiplying so the byte count cannot overflow int arithmetic.
                long numBytes = (long) N * C * P * Q * sizeOfDataType;
                scratchY = gCtx.allocate(numBytes);
                y = scratchY;
                if (DMLScript.FINEGRAINED_STATISTICS) {
                    GPUStatistics.maintainCPMiscTimes(instName, "nni", System.nanoTime() - t1);
                }
                long t2 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
                status = JCudnn.cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
                if (DMLScript.FINEGRAINED_STATISTICS) {
                    GPUStatistics.maintainCPMiscTimes(instName, "nnmf", System.nanoTime() - t2);
                }
                if (status != 0) {
                    throw new DMLRuntimeException("Could not executed cudnnPoolingForward before cudnnPoolingBackward: " + cudnnStatus.stringFor(status));
                }
            }
            long t3 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            status = JCudnn.cudnnPoolingBackward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.yDesc, y, desc.dyDesc, dy, desc.xDesc, x, zero(), desc.dxDesc, dx);
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nnmb", System.nanoTime() - t3);
            }
            if (status != 0) {
                throw new DMLRuntimeException("Could not executed cudnnPoolingBackward: " + cudnnStatus.stringFor(status));
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
        finally {
            long t4 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            // Skip the free when the allocation itself failed and left nothing to release.
            if (scratchY != null) {
                gCtx.cudaFreeHelper(instName, scratchY);
            }
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nnc", System.nanoTime() - t4);
            }
        }
    }

    /**
     * Applies a cuDNN activation to {@code in}, writing into {@code dstData}. The same tensor
     * descriptor is used for source and destination. The activation descriptor created here is
     * destroyed before returning; that cleanup is timed under "nnc" and the forward call
     * under "nnaf".
     */
    private static void cudnnReLU(GPUContext gCtx, String instName, MatrixObject in, Pointer dstData, cudnnTensorDescriptor srcTensorDesc) throws DMLRuntimeException {
        // Assigned only after the native descriptor was successfully created.
        cudnnActivationDescriptor activationDescriptor = null;
        try {
            if (LOG.isTraceEnabled()) {
                LOG.trace("GPU : performCuDNNReLU, GPUContext=" + gCtx);
            }
            Pointer srcData = getDensePointerForCuDNN(gCtx, in, instName);
            cudnnActivationDescriptor created = new cudnnActivationDescriptor();
            JCudnn.cudnnCreateActivationDescriptor(created);
            activationDescriptor = created;
            // Mode 1 / NaN option 1: CUDNN_ACTIVATION_RELU / CUDNN_PROPAGATE_NAN per the cuDNN
            // enum values. The coefficient is a placeholder (named "dummy" in the original code).
            double dummy = -1.0;
            JCudnn.cudnnSetActivationDescriptor(activationDescriptor, 1, 1, dummy);
            long t0 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            checkStatus(JCudnn.cudnnActivationForward(getCudnnHandle(gCtx), activationDescriptor, one(), srcTensorDesc, srcData, zero(), srcTensorDesc, dstData));
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nnaf", System.nanoTime() - t0);
            }
        }
        catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
        finally {
            long t1 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            if (activationDescriptor != null) {
                JCudnn.cudnnDestroyActivationDescriptor(activationDescriptor);
            }
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nnc", System.nanoTime() - t1);
            }
        }
    }

    /**
     * Computes ReLU of {@code in} into the matrix named {@code outputName}.
     *
     * <p>Inputs with at least {@code maxNumElementsOfCuDNNTensor} cells are handled by the custom
     * {@code relu} CUDA kernel; smaller inputs go through cuDNN, described as a
     * {@code [rows, 1, 1, columns]} tensor.
     *
     * @param ec execution context; its GPU context 0 must be {@code gCtx}
     * @param gCtx GPU context
     * @param instName instruction name used for statistics
     * @param in input matrix
     * @param outputName name of the output variable
     * @throws DMLRuntimeException if the contexts do not match or the computation fails
     */
    public static void relu(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName) throws DMLRuntimeException {
        if (ec.getGPUContext(0) != gCtx) {
            throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
        }
        long N = in.getNumRows();
        long CHW = in.getNumColumns();
        MatrixObject output = ec.getMatrixObject(outputName);
        getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in.getNumRows(), in.getNumColumns());
        if (N * CHW >= maxNumElementsOfCuDNNTensor) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("GPU : relu custom kernel, GPUContext=" + gCtx);
            }
            long t0 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
            Pointer dstData = getDensePointerForCuDNN(gCtx, output, instName);
            Pointer srcData = getDensePointerForCuDNN(gCtx, in, instName);
            getCudaKernels(gCtx).launchKernel("relu", ExecutionConfig.getConfigForSimpleMatrixOperations(toInt(N), toInt(CHW)), srcData, dstData, toInt(N), toInt(CHW));
            if (DMLScript.FINEGRAINED_STATISTICS) {
                GPUStatistics.maintainCPMiscTimes(instName, "nnrk", System.nanoTime() - t0);
            }
        } else {
            cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
            JCudnn.cudnnCreateTensorDescriptor(tensorDescriptor);
            try {
                JCudnn.cudnnSetTensor4dDescriptor(tensorDescriptor, 0, CUDNN_DATA_TYPE, toInt(N), 1, 1, toInt(CHW));
                cudnnReLU(gCtx, instName, in, getDensePointerForCuDNN(gCtx, output, instName), tensorDescriptor);
            } finally {
                // Release the native descriptor even when the activation fails.
                JCudnn.cudnnDestroyTensorDescriptor(tensorDescriptor);
            }
        }
    }

    /**
     * Returns the dense device pointer of {@code image} after checking that its cell count does
     * not exceed {@code maxNumElementsOfCuDNNTensor}.
     *
     * @param gCtx GPU context
     * @param image matrix whose dense pointer is requested
     * @param instName instruction name used for statistics
     * @return dense pointer obtained from {@code getDensePointer}
     * @throws DMLRuntimeException if the matrix has more cells than cuDNN supports
     */
    protected static Pointer getDensePointerForCuDNN(GPUContext gCtx, MatrixObject image, String instName) throws DMLRuntimeException {
        long numElems = image.getNumRows() * image.getNumColumns();
        if (numElems > maxNumElementsOfCuDNNTensor) {
            throw new DMLRuntimeException("CuDNN restriction: the size of input tensor cannot have greater than 2 giga-elements, but has " + numElems + " (i.e. [" + image.getNumRows() + " X " + image.getNumColumns() + "]). Hint: try reducing the mini-batch size.");
        }
        return getDensePointer(gCtx, image, instName);
    }

    /**
     * Throws if {@code status} is non-zero, including the cuDNN status name in the message.
     *
     * @param status status code returned by a cuDNN call
     * @throws DMLRuntimeException if {@code status != 0}
     */
    protected static void checkStatus(int status) throws DMLRuntimeException {
        if (status != 0) {
            throw new DMLRuntimeException("Error status returned by CuDNN:" + cudnnStatus.stringFor(status));
        }
    }
}

