/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.concurrent.Callable;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNIm2ColHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.utils.NativeHelper;

public class LibMatrixDNNConv2dHelper {

    /**
     * Task that runs the native sparse conv2d kernel over the input rows
     * {@code [rl, ru)} of {@code params.input1}.
     *
     * <p>For every row whose sparse representation is non-empty, the row's
     * sparse data is handed to {@code NativeHelper.conv2dSparse} together with the
     * dense block of {@code params.input2}; the kernel's result buffer is then
     * copied into the dense output block at offset {@code row * K*P*Q}. Empty
     * rows are skipped and their output range is left untouched.
     */
    public static class SparseNativeConv2d
    implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;

        /**
         * @param rl first input row to process (inclusive)
         * @param ru last input row to process (exclusive)
         * @param params convolution parameters holding input, filter and output
         */
        public SparseNativeConv2d(int rl, int ru, ConvolutionParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Processes the assigned row range.
         *
         * @return the value of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final int outRowLen = _params.K * _params.P * _params.Q;
            // Scratch buffer reused for every row; the native call fills it.
            final double[] rowResult = new double[outRowLen];
            for (int row = _rl; row < _ru; row++) {
                if (!_params.input1.getSparseBlock().isEmpty(row)) {
                    int rowStart = _params.input1.getSparseBlock().pos(row);
                    int rowNnz = _params.input1.getSparseBlock().size(row);
                    int[] colIndexes = _params.input1.getSparseBlock().indexes(row);
                    double[] rowValues = _params.input1.getSparseBlock().values(row);
                    // NOTE(review): the literal 1 arguments are presumably the batch size
                    // (one row per call) and the native thread count - confirm against NativeHelper.
                    NativeHelper.conv2dSparse(
                        rowStart, rowNnz, colIndexes, rowValues,
                        _params.input2.getDenseBlock(), rowResult,
                        1, _params.C, _params.H, _params.W,
                        _params.K, _params.R, _params.S,
                        _params.stride_h, _params.stride_w,
                        _params.pad_h, _params.pad_w,
                        _params.P, _params.Q, 1);
                    System.arraycopy(rowResult, 0, _params.output.denseBlock, row * outRowLen, outRowLen);
                }
            }
            // Non-zero count is maintained only for the rows this task worked on.
            return _params.output.recomputeNonZeros(_rl, _ru - 1);
        }
    }

    public static class LoopedIm2ColConv2dTransAllChan
    extends LoopedIm2ColConv2dAllChan {
        public LoopedIm2ColConv2dTransAllChan(int rl, int ru, ConvolutionParameters params) {
            super(rl, ru, params);
        }

        @Override
        public Long call() throws Exception {
            int PQ = this._params.P * this._params.Q;
            int K = this._params.K;
            int CRS = this._params.C * this._params.R * this._params.S;
            MatrixBlock outIm2col = new MatrixBlock(PQ, CRS, false);
            MatrixBlock outMM = new MatrixBlock(PQ, K, false);
            LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker(this._params.input1, outIm2col, this._params, true, true);
            for (int n = this._rl; n < this._ru; ++n) {
                im2ColWorker.execute(n);
                outMM.reset(outMM.rlen, outMM.clen, false);
                LibMatrixDNNHelper.singleThreadedMatMult(outIm2col, this._params.input2, outMM, false, false, this._params);
                LoopedIm2ColConv2dTransAllChan.partialCopyTrans(outMM, this._params.output, n * K * PQ, K, PQ);
                if (this._params.bias == null) continue;
                LibMatrixDNNHelper.addBias(n, this._params.output.getDenseBlock(), this._params.bias.getDenseBlock(), K, PQ);
            }
            return this._params.output.recomputeNonZeros(this._rl, this._ru - 1);
        }

        private static void partialCopyTrans(MatrixBlock src, MatrixBlock dest, int destPos, int K, int PQ) {
            if (src.isEmptyBlock()) {
                return;
            }
            if (src.isInSparseFormat()) {
                SparseBlock sblock = src.sparseBlock;
                double[] c = dest.denseBlock;
                for (int i = 0; i < src.getNumRows(); ++i) {
                    if (sblock.isEmpty(i)) continue;
                    int apos = sblock.pos(i);
                    int alen = sblock.size(i);
                    int[] aix = sblock.indexes(i);
                    double[] avals = sblock.values(i);
                    int desPosK = destPos + i;
                    for (int j = apos; j < apos + alen; ++j) {
                        c[desPosK + aix[j] * PQ] = avals[j];
                    }
                }
            } else {
                double[] a = src.denseBlock;
                double[] c = dest.denseBlock;
                int blocksizeIJ = 128;
                for (int bi = 0; bi < PQ; bi += 128) {
                    for (int bj = 0; bj < K; bj += 128) {
                        int bimin = Math.min(bi + 128, PQ);
                        int bjmin = Math.min(bj + 128, K);
                        int i = bi;
                        int aix = bi * K + bj;
                        int cix = bj * PQ + bi;
                        while (i < bimin) {
                            LibMatrixReorg.transposeRow(a, c, aix, destPos + cix, PQ, bjmin - bj);
                            ++i;
                            aix += K;
                            ++cix;
                        }
                    }
                }
            }
        }
    }

    /**
     * Task that, for each input row in {@code [rl, ru)}, runs im2col into a
     * {@code CRS x PQ} buffer, multiplies {@code params.input2} with that buffer
     * into a {@code K x PQ} block, copies the product into the dense output at
     * offset {@code n * K * PQ}, and adds the bias when one is set.
     */
    public static class LoopedIm2ColConv2dAllChan
    implements Callable<Long> {
        protected final int _rl;
        protected final int _ru;
        protected final ConvolutionParameters _params;

        /**
         * @param rl first input row to process (inclusive)
         * @param ru last input row to process (exclusive)
         * @param params convolution parameters holding input, filter, bias and output
         */
        public LoopedIm2ColConv2dAllChan(int rl, int ru, ConvolutionParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Processes the assigned row range and, when
         * {@code DMLScript.FINEGRAINED_STATISTICS} is set, adds the accumulated
         * im2col and matrix-multiply times to the {@code LibMatrixDNN} counters.
         *
         * @return the value of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final int K = _params.K;
            final int PQ = _params.P * _params.Q;
            final int CRS = _params.C * _params.R * _params.S;
            MatrixBlock im2colBuffer = new MatrixBlock(CRS, PQ, false);
            MatrixBlock product = new MatrixBlock(K, PQ, false);
            LibMatrixDNNIm2ColHelper.Im2colWorker worker =
                LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker(_params.input1, im2colBuffer, _params, true, false);
            long im2colNanos = 0L;
            long matMultNanos = 0L;
            for (int n = _rl; n < _ru; n++) {
                long beforeIm2col = nanoTimeIfProfiling();
                worker.execute(n);
                long afterIm2col = nanoTimeIfProfiling();
                // The product block is reused across rows, so it is reset before each multiply.
                product.reset(product.rlen, product.clen, false);
                LibMatrixDNNHelper.singleThreadedMatMult(_params.input2, im2colBuffer, product, false, true, _params);
                long afterMatMult = nanoTimeIfProfiling();
                if (DMLScript.FINEGRAINED_STATISTICS) {
                    im2colNanos += afterIm2col - beforeIm2col;
                    matMultNanos += afterMatMult - afterIm2col;
                }
                partialCopy1(product, _params.output.getDenseBlock(), n * K * PQ, K, PQ);
                if (_params.bias != null) {
                    LibMatrixDNNHelper.addBias(n, _params.output.getDenseBlock(), _params.bias.getDenseBlock(), K, PQ);
                }
            }
            if (DMLScript.FINEGRAINED_STATISTICS) {
                LibMatrixDNN.loopedConvIm2ColTime.addAndGet(im2colNanos);
                LibMatrixDNN.loopedConvMatMultTime.addAndGet(matMultNanos);
            }
            return _params.output.recomputeNonZeros(_rl, _ru - 1);
        }

        /**
         * Returns {@code System.nanoTime()} when fine-grained statistics are
         * enabled and {@code 0} otherwise; the flag is read on every call.
         */
        private static long nanoTimeIfProfiling() {
            return DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        }

        /**
         * Copies {@code src} into {@code dest} starting at {@code destPos} without
         * changing the layout. A dense source is copied as one run of
         * {@code K * PQ} values; a sparse source writes only its stored entries, at
         * {@code destPos + row * PQ + column}. An empty source writes nothing.
         *
         * @param src block to copy from
         * @param dest destination array
         * @param destPos offset of the target region in {@code dest}
         * @param K number of rows covered by the dense copy
         * @param PQ row stride in the destination
         */
        private static void partialCopy1(MatrixBlock src, double[] dest, int destPos, int K, int PQ) {
            if (src.isEmptyBlock()) {
                return;
            }
            if (!src.isInSparseFormat()) {
                System.arraycopy(src.denseBlock, 0, dest, destPos, K * PQ);
                return;
            }
            SparseBlock rows = src.sparseBlock;
            for (int k = 0; k < src.getNumRows(); k++) {
                if (!rows.isEmpty(k)) {
                    int start = rows.pos(k);
                    int end = start + rows.size(k);
                    int[] cols = rows.indexes(k);
                    double[] vals = rows.values(k);
                    int base = destPos + k * PQ;
                    for (int j = start; j < end; j++) {
                        dest[base + cols[j]] = vals[j];
                    }
                }
            }
        }
    }

    /**
     * Task that processes each input row in {@code [rl, ru)} one channel at a
     * time: im2col for channel {@code c} fills an {@code RS x PQ} buffer, the
     * {@code c}-th filter block is multiplied with it into a fresh
     * {@code K x PQ} block, and that product is added onto the dense output at
     * offset {@code n * K * PQ}. The bias is added once per row when one is set.
     */
    public static class LoopedIm2ColConv2dOneChan
    implements Callable<Long> {
        protected final int _rl;
        protected final int _ru;
        protected final ConvolutionParameters _params;
        protected final ArrayList<MatrixBlock> _filters;

        /**
         * @param rl first input row to process (inclusive)
         * @param ru last input row to process (exclusive)
         * @param params convolution parameters holding input, bias and output
         * @param filters filter blocks, looked up by channel index
         */
        public LoopedIm2ColConv2dOneChan(int rl, int ru, ConvolutionParameters params, ArrayList<MatrixBlock> filters) {
            _rl = rl;
            _ru = ru;
            _params = params;
            _filters = filters;
        }

        /**
         * Processes the assigned row range and, when
         * {@code DMLScript.FINEGRAINED_STATISTICS} is set, adds the accumulated
         * im2col and matrix-multiply times to the {@code LibMatrixDNN} counters.
         *
         * @return the value of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final int K = _params.K;
            final int PQ = _params.P * _params.Q;
            final int RS = _params.R * _params.S;
            MatrixBlock im2colBuffer = new MatrixBlock(RS, PQ, false);
            LibMatrixDNNIm2ColHelper.Im2colWorker worker =
                LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker(_params.input1, im2colBuffer, _params, false, false);
            long im2colNanos = 0L;
            long matMultNanos = 0L;
            for (int n = _rl; n < _ru; n++) {
                for (int c = 0; c < _params.C; c++) {
                    long beforeIm2col = nanoTimeIfProfiling();
                    worker.execute(n, c);
                    long afterIm2col = nanoTimeIfProfiling();
                    MatrixBlock product = new MatrixBlock(K, PQ, false);
                    LibMatrixDNNHelper.singleThreadedMatMult(_filters.get(c), im2colBuffer, product, false, true, _params);
                    long afterMatMult = nanoTimeIfProfiling();
                    if (DMLScript.FINEGRAINED_STATISTICS) {
                        im2colNanos += afterIm2col - beforeIm2col;
                        matMultNanos += afterMatMult - afterIm2col;
                    }
                    // Per-channel products are summed into the same output row.
                    add(product, _params.output.getDenseBlock(), n * K * PQ, K, PQ);
                }
                if (_params.bias != null) {
                    LibMatrixDNNHelper.addBias(n, _params.output.getDenseBlock(), _params.bias.getDenseBlock(), K, PQ);
                }
            }
            if (DMLScript.FINEGRAINED_STATISTICS) {
                LibMatrixDNN.loopedConvIm2ColTime.addAndGet(im2colNanos);
                LibMatrixDNN.loopedConvMatMultTime.addAndGet(matMultNanos);
            }
            return _params.output.recomputeNonZeros(_rl, _ru - 1);
        }

        /**
         * Returns {@code System.nanoTime()} when fine-grained statistics are
         * enabled and {@code 0} otherwise; the flag is read on every call.
         */
        private static long nanoTimeIfProfiling() {
            return DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0L;
        }

        /**
         * Adds the values of {@code src} onto {@code dest} starting at
         * {@code destPos}. A dense source is added via
         * {@code LibMatrixMult.vectAdd} over {@code K * PQ} values; a sparse source
         * adds only its stored entries, at {@code destPos + row * PQ + column}. An
         * empty source changes nothing.
         *
         * @param src block whose values are added
         * @param dest array that accumulates the sum
         * @param destPos offset of the target region in {@code dest}
         * @param K number of rows covered by the dense add
         * @param PQ row stride in the destination
         */
        private static void add(MatrixBlock src, double[] dest, int destPos, int K, int PQ) {
            if (src.isEmptyBlock()) {
                return;
            }
            if (!src.isInSparseFormat()) {
                LibMatrixMult.vectAdd(src.denseBlock, dest, 0, destPos, K * PQ);
                return;
            }
            for (int k = 0; k < src.getNumRows(); k++) {
                if (!src.sparseBlock.isEmpty(k)) {
                    int start = src.sparseBlock.pos(k);
                    int end = start + src.sparseBlock.size(k);
                    int[] cols = src.sparseBlock.indexes(k);
                    double[] vals = src.sparseBlock.values(k);
                    int base = destPos + k * PQ;
                    for (int j = start; j < end; j++) {
                        dest[base + cols[j]] += vals[j];
                    }
                }
            }
        }
    }
}

