/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.concurrent.Callable;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNIm2ColHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNRotate180Helper;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.utils.NativeHelper;

public class LibMatrixDNNConv2dBackwardFilterHelper {
    /**
     * Merges a task-local partial result into the shared dense output array.
     *
     * <p>Delegates to {@code LibMatrixMult.vectAdd}, passing {@code a} as the
     * first array and {@code params.output.denseBlock} as the second, both at
     * offset 0, for {@code a.length} elements (presumably an element-wise
     * {@code output += a}; confirm against {@code vectAdd}). The call runs
     * while holding the monitor of the output array, so callers that merge
     * through this method are serialized on that array.
     *
     * @param a      partial result to merge; all {@code a.length} entries are passed on
     * @param params convolution parameters whose {@code output.denseBlock} is updated
     */
    private static void inplaceAdd(double[] a, ConvolutionParameters params) {
        // Read the field once so the array we lock on is exactly the array we write to.
        final double[] out = params.output.denseBlock;
        synchronized (out) {
            LibMatrixMult.vectAdd(a, out, 0, 0, a.length);
        }
    }

    /**
     * Adds a task-local partial result into the shared dense output array,
     * transposing it on the fly.
     *
     * <p>With {@code CRS = params.C * params.R * params.S} and
     * {@code K = params.K}, element {@code a[i * K + j]} is added to
     * {@code c[j * CRS + i]} for every {@code 0 <= i < CRS} and
     * {@code 0 <= j < K}, where {@code c} is {@code params.output.denseBlock}.
     * The index space is walked in square tiles of {@code blocksizeIJ}
     * elements per side. The whole update runs while holding the monitor of
     * the output array.
     *
     * @param a      partial result, indexed as {@code i * K + j}
     * @param params convolution parameters supplying {@code C}, {@code R},
     *               {@code S}, {@code K} and the output array to update
     */
    private static void inplaceTransAdd(double[] a, ConvolutionParameters params) {
        // Read the field once so the array we lock on is exactly the array we write to.
        final double[] c = params.output.denseBlock;
        synchronized (c) {
            final int CRS = params.C * params.R * params.S;
            final int K = params.K;
            // Tile edge length of the blocked traversal.
            final int blocksizeIJ = 128;
            for (int bi = 0; bi < CRS; bi += blocksizeIJ) {
                // Upper row bound depends only on bi, so compute it once per row tile.
                final int bimin = Math.min(bi + blocksizeIJ, CRS);
                for (int bj = 0; bj < K; bj += blocksizeIJ) {
                    final int bjmin = Math.min(bj + blocksizeIJ, K);
                    // aix walks rows of a (stride K); cix walks the transposed position in c (stride CRS).
                    for (int i = bi, aix = bi * K; i < bimin; i++, aix += K) {
                        for (int j = bj, cix = i + bj * CRS; j < bjmin; j++, cix += CRS) {
                            c[cix] += a[aix + j];
                        }
                    }
                }
            }
        }
    }

    /**
     * Task that handles the index range {@code [rl, ru)} and merges its
     * accumulated result into the shared output through {@code inplaceAdd}.
     *
     * <p>For each index {@code n} it runs the rotate180 worker, then the im2col
     * worker, multiplies the rotated block by the im2col block into a
     * {@code K x CRS} block, and accumulates that block into a task-local array.
     * Both workers are obtained with {@code (true, true)} as their trailing
     * flags (presumably selecting the transposed layouts; confirm against the
     * worker factories).
     */
    public static class Conv2dBackwardFilterTrans
    implements Callable<Long> {
        private final int _rl;
        private final int _ru;
        private final ConvolutionParameters _params;

        public Conv2dBackwardFilterTrans(int rl, int ru, ConvolutionParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /** Returns {@code System.nanoTime()} when fine-grained statistics are on, otherwise 0. */
        private static long timestamp() {
            return DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        }

        /**
         * Runs the task for indices {@code _rl} (inclusive) to {@code _ru} (exclusive).
         *
         * @return always {@code 0L}
         */
        @Override
        public Long call() throws Exception {
            final int PQ = _params.P * _params.Q;
            final int K = _params.K;
            final int CRS = _params.C * _params.R * _params.S;
            final MatrixBlock dout = _params.input2;

            // Scratch blocks reused across iterations.
            final MatrixBlock im2colOut = new MatrixBlock(PQ, CRS, false).allocateBlock();
            final MatrixBlock rotated = new MatrixBlock(K, PQ, dout.sparse).allocateBlock();
            final MatrixBlock product = new MatrixBlock(K, CRS, false).allocateBlock();

            final LibMatrixDNNIm2ColHelper.Im2colWorker im2col =
                LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker(_params.input1, im2colOut, _params, true, true);
            final LibMatrixDNNRotate180Helper.Rotate180Worker rotator =
                LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker(dout, rotated, _params, true, true);

            // Task-local accumulator; merged into the shared output once at the end.
            final double[] partial = new double[CRS * K];
            long im2colNanos = 0L;
            long matMultNanos = 0L;

            for (int n = _rl; n < _ru; n++) {
                rotator.execute(n, 0);

                final long beforeIm2col = timestamp();
                im2col.execute(n);
                final long afterIm2col = timestamp();

                product.reset(K, CRS, false);
                LibMatrixDNNHelper.singleThreadedMatMult(rotated, im2colOut, product, !rotated.sparse, !im2colOut.sparse, _params);
                final long afterMatMult = timestamp();

                // An empty product contributes nothing, so skip the accumulation.
                if (!product.isEmptyBlock()) {
                    LibMatrixMult.vectAdd(product.getDenseBlock(), partial, 0, 0, K * CRS);
                }

                if (DMLScript.FINEGRAINED_STATISTICS) {
                    im2colNanos += afterIm2col - beforeIm2col;
                    matMultNanos += afterMatMult - afterIm2col;
                }
            }

            LibMatrixDNNConv2dBackwardFilterHelper.inplaceAdd(partial, _params);

            if (DMLScript.FINEGRAINED_STATISTICS) {
                LibMatrixDNN.loopedConvBwdFilterIm2ColTime.addAndGet(im2colNanos);
                LibMatrixDNN.loopedConvBwdFilterMatMultTime.addAndGet(matMultNanos);
            }
            return 0L;
        }
    }

    /**
     * Task that handles the index range {@code [rl, ru)} and merges its
     * accumulated result into the shared output through
     * {@code inplaceTransAdd}.
     *
     * <p>For each index {@code n} it runs the rotate180 worker, then the im2col
     * worker, multiplies the im2col block by the rotated block into a
     * {@code CRS x K} block, and accumulates that block into a task-local array.
     * Only the rotated block is allocated up front here; the im2col and
     * product blocks are constructed without an explicit allocation.
     */
    public static class Conv2dBackwardFilter
    implements Callable<Long> {
        private final int _rl;
        private final int _ru;
        private final ConvolutionParameters _params;

        public Conv2dBackwardFilter(int rl, int ru, ConvolutionParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /** Returns {@code System.nanoTime()} when fine-grained statistics are on, otherwise 0. */
        private static long timestamp() {
            return DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        }

        /**
         * Runs the task for indices {@code _rl} (inclusive) to {@code _ru} (exclusive).
         *
         * @return always {@code 0L}
         */
        @Override
        public Long call() throws Exception {
            final int PQ = _params.P * _params.Q;
            final int K = _params.K;
            final int CRS = _params.C * _params.R * _params.S;
            final MatrixBlock dout = _params.input2;

            // Scratch blocks reused across iterations.
            final MatrixBlock im2colOut = new MatrixBlock(CRS, PQ, false);
            final MatrixBlock rotated = new MatrixBlock(PQ, K, dout.sparse);
            final MatrixBlock product = new MatrixBlock(CRS, K, false);
            rotated.allocateBlock();

            final LibMatrixDNNIm2ColHelper.Im2colWorker im2col =
                LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker(_params.input1, im2colOut, _params, true, false);
            final LibMatrixDNNRotate180Helper.Rotate180Worker rotator =
                LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker(dout, rotated, _params, true, false);

            // Task-local accumulator; merged into the shared output once at the end.
            final double[] partial = new double[CRS * K];
            long im2colNanos = 0L;
            long matMultNanos = 0L;

            for (int n = _rl; n < _ru; n++) {
                rotator.execute(n, 0);

                final long beforeIm2col = timestamp();
                im2col.execute(n);
                final long afterIm2col = timestamp();

                product.reset(CRS, K, false);
                LibMatrixDNNHelper.singleThreadedMatMult(im2colOut, rotated, product, !im2colOut.sparse, !rotated.sparse, _params);
                final long afterMatMult = timestamp();

                // An empty product contributes nothing, so skip the accumulation.
                if (!product.isEmptyBlock()) {
                    LibMatrixMult.vectAdd(product.getDenseBlock(), partial, 0, 0, K * CRS);
                }

                if (DMLScript.FINEGRAINED_STATISTICS) {
                    im2colNanos += afterIm2col - beforeIm2col;
                    matMultNanos += afterMatMult - afterIm2col;
                }
            }

            // The accumulator is indexed as CRS x K; the transposing merge writes it as K x CRS.
            LibMatrixDNNConv2dBackwardFilterHelper.inplaceTransAdd(partial, _params);

            if (DMLScript.FINEGRAINED_STATISTICS) {
                LibMatrixDNN.loopedConvBwdFilterIm2ColTime.addAndGet(im2colNanos);
                LibMatrixDNN.loopedConvBwdFilterMatMultTime.addAndGet(matMultNanos);
            }
            return 0L;
        }
    }

    /**
     * Task that handles the index range {@code [rl, ru)} for a sparse first
     * input by calling {@code NativeHelper.conv2dBackwardFilterSparseDense}
     * once per non-empty sparse row, then merging the task-local result into
     * the shared output through {@code inplaceTransAdd}.
     *
     * <p>NOTE(review): {@code _rl} and {@code _ru} are public and non-final
     * here, unlike the sibling task classes; kept as-is because they are
     * visible to callers.
     */
    public static class SparseNativeConv2dBackwardFilterDense
    implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;

        public SparseNativeConv2dBackwardFilterDense(int rl, int ru, ConvolutionParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Runs the task for indices {@code _rl} (inclusive) to {@code _ru} (exclusive).
         *
         * @return always {@code 0L}
         */
        @Override
        public Long call() throws Exception {
            final int CRS = _params.C * _params.R * _params.S;
            final int PQ = _params.P * _params.Q;
            final int K = _params.K;

            // Dense scratch block that receives the rotate180 output for the current index.
            final MatrixBlock rotated = new MatrixBlock(PQ, K, false);
            rotated.allocateBlock();
            final LibMatrixDNNRotate180Helper.Rotate180Worker rotator =
                LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker(_params.input2, rotated, _params, true, false);
            final double[] rotatedValues = rotated.getDenseBlock();

            // Task-local result buffer handed to the native kernel on every call.
            final double[] partial = new double[CRS * K];
            final MatrixBlock input = _params.input1;

            for (int n = _rl; n < _ru; n++) {
                // Rows without non-zeros are skipped entirely (no rotate, no native call).
                if (!input.getSparseBlock().isEmpty(n)) {
                    rotator.execute(n, 0);
                    final int rowPos = input.getSparseBlock().pos(n);
                    final int rowLen = input.getSparseBlock().size(n);
                    final int[] rowIndexes = input.getSparseBlock().indexes(n);
                    final double[] rowValues = input.getSparseBlock().values(n);
                    NativeHelper.conv2dBackwardFilterSparseDense(
                        rowPos, rowLen, rowIndexes, rowValues, rotatedValues, partial,
                        1, _params.C, _params.H, _params.W, K, _params.R, _params.S,
                        _params.stride_h, _params.stride_w, _params.pad_h, _params.pad_w,
                        _params.P, _params.Q, 1);
                }
            }

            LibMatrixDNNConv2dBackwardFilterHelper.inplaceTransAdd(partial, _params);
            return 0L;
        }
    }
}

