/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.Arrays;
import java.util.concurrent.Callable;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;

public class LibMatrixDNNPoolingBackwardHelper {

    /**
     * Max-pooling backward task for a sparse pooling input ({@code params.input1}) and a sparse
     * output gradient ({@code params.input2}); gradients are accumulated into the dense block of
     * {@code params.output}.
     *
     * <p>Reuses the forward arg-max computation inherited from {@link PoolingBackwardSparseDense}
     * and only replaces the backward scatter step so that it iterates the non-zeros of the sparse
     * dout row instead of a dense dout block.
     */
    public static class PoolingBackwardSparseSparse
    extends PoolingBackwardSparseDense {
        /**
         * @param rl first row to process (inclusive)
         * @param ru last row to process (exclusive)
         * @param params convolution parameters holding input1 (sparse pooling input), input2
         *     (sparse dout) and output
         * @param relu whether negative input values are treated as 0 in the max search
         * @throws RuntimeException if output has no dense block, or if input1 or input2 is not in
         *     sparse format
         */
        public PoolingBackwardSparseSparse(int rl, int ru, ConvolutionParameters params, boolean relu) {
            super(rl, ru, params, relu, params.input2, params.output);
            if (this.output.getDenseBlock() == null) {
                throw new RuntimeException("Incorrect usage: empty outputs");
            }
            if (!params.input1.isInSparseFormat() || !params.input2.isInSparseFormat()) {
                throw new RuntimeException("Incorrect usage: Call optimized versions");
            }
        }

        /**
         * Adds the non-zeros of channel {@code c} of sparse dout row {@code n} (column range
         * {@code [c*PQ, (c+1)*PQ)}) to the output cells {@code outOffset + maxIx[pq]}, where
         * {@code pq} is the window index of the entry within the channel.
         */
        @Override
        protected void maxpoolingBackward(int[] maxIx, int outOffset, int n, int c, int C, int Q, int PQ, int CPQ) {
            SparseBlock sblock = this.doutput.getSparseBlock();
            // A null sparse block (presumably an unallocated, all-zero dout) contributes nothing,
            // exactly like an empty row; guard it instead of dereferencing null.
            if (sblock == null || sblock.isEmpty(n)) {
                return;
            }
            double[] out = this.output.getDenseBlock();
            int apos = sblock.pos(n);
            int alen = sblock.size(n);
            int[] aix = sblock.indexes(n);
            double[] avals = sblock.values(n);
            // Locate channel c's sub-range of the row; negative posFIndexGTE results are mapped to
            // the end of the row.
            int cpos = c == 0 ? 0 : sblock.posFIndexGTE(n, c * PQ);
            int cpos2 = c + 1 == C ? alen : sblock.posFIndexGTE(n, (c + 1) * PQ);
            cpos = cpos >= 0 ? cpos : alen;
            cpos2 = cpos2 >= 0 ? cpos2 : alen;
            for (int j = apos + cpos; j < apos + cpos2; ++j) {
                // Window index within the channel; identical to (aix[j] % PQ / Q) * Q + aix[j] % Q.
                int pq = aix[j] % PQ;
                out[outOffset + maxIx[pq]] += avals[j];
            }
        }
    }

    /**
     * Max-pooling backward task for a sparse pooling input ({@code params.input1}) and a dense
     * output gradient ({@code params.input2}); gradients are accumulated into the dense block of
     * {@code params.output}.
     *
     * <p>For every row {@code n} in {@code [rl, ru)} and every channel {@code c}, the forward
     * arg-max of each of the {@code P*Q} pooling windows is recomputed in one sequential pass over
     * the sparse input row ({@code maxpoolingForward}), and the matching dout values are then added
     * at those positions ({@code maxpoolingBackward}). Subclasses may override the backward step to
     * read dout in another format.
     */
    public static class PoolingBackwardSparseDense
    implements Callable<Long> {
        private final int _rl;
        private final int _ru;
        private final ConvolutionParameters _params;
        // If true, negative input values are treated as 0 when searching the window maximum.
        private final boolean reluBack;
        // Output gradient (dout).
        protected final MatrixBlock doutput;
        // Result block; gradients are accumulated into its dense block.
        protected final MatrixBlock output;

        /**
         * Creates a task without any format validation; used by the public constructor and by
         * subclasses that supply their own dout/output blocks and checks.
         *
         * @param rl first row to process (inclusive)
         * @param ru last row to process (exclusive)
         * @param params convolution parameters; {@code input1} is the sparse pooling input
         * @param relu whether negative input values are treated as 0 in the max search
         * @param dout output-gradient block
         * @param out result block whose dense block receives the gradients
         */
        protected PoolingBackwardSparseDense(int rl, int ru, ConvolutionParameters params, boolean relu, MatrixBlock dout, MatrixBlock out) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
            this.reluBack = relu;
            this.doutput = dout;
            this.output = out;
        }

        /**
         * @param rl first row to process (inclusive)
         * @param ru last row to process (exclusive)
         * @param params convolution parameters holding input1 (sparse pooling input), input2
         *     (dense dout) and output
         * @param relu whether negative input values are treated as 0 in the max search
         * @throws RuntimeException if dout or output has no dense block, or if input1 is not in
         *     sparse format
         */
        public PoolingBackwardSparseDense(int rl, int ru, ConvolutionParameters params, boolean relu) {
            this(rl, ru, params, relu, params.input2, params.output);
            if (this.doutput.getDenseBlock() == null || this.output.getDenseBlock() == null) {
                throw new RuntimeException("Incorrect usage: empty inputs");
            }
            if (!params.input1.isInSparseFormat()) {
                throw new RuntimeException("Incorrect usage: sparse input1 expected");
            }
        }

        /**
         * Runs max-pooling backward for rows {@code [rl, ru)}.
         *
         * @return the result of {@code output.recomputeNonZeros(rl, ru - 1)} (presumably the
         *     non-zero count of the processed output rows)
         */
        @Override
        public Long call() throws Exception {
            int P = this._params.P;
            int Q = this._params.Q;
            int W = this._params.W;
            int C = this._params.C;
            int R = this._params.R;
            int S = this._params.S;
            int padh = this._params.pad_h;
            int padw = this._params.pad_w;
            int strideh = this._params.stride_h;
            int stridew = this._params.stride_w;
            int PQ = this._params.P * this._params.Q;
            int CPQ = this._params.C * this._params.P * this._params.Q;
            int HW = this._params.H * this._params.W;
            int CHW = this._params.C * this._params.H * this._params.W;
            // Scratch buffers, reused for every (row, channel) pair of this task.
            double[] maxVal = new double[PQ];
            int[] maxIx = new int[PQ];
            for (int n = this._rl; n < this._ru; ++n) {
                for (int c = 0; c < C; ++c) {
                    // Offset of channel c of row n in the output's dense block.
                    int outOffset = n * CHW + c * HW;
                    this.maxpoolingForward(maxVal, maxIx, n, c, padh, padw, strideh, stridew, C, P, Q, R, S, HW, W);
                    this.maxpoolingBackward(maxIx, outOffset, n, c, C, Q, PQ, CPQ);
                }
            }
            return this.output.recomputeNonZeros(this._rl, this._ru - 1);
        }

        /**
         * Computes the maximum value ({@code maxVal}) and its within-channel position
         * {@code h*W + w} ({@code maxIx}) for each pooling window {@code p*Q + q}, using channel
         * {@code c} of sparse input row {@code n}.
         *
         * <p>Implicit zeros take part in the comparison. Cells are offered in increasing column
         * order (assuming the sparse row's indexes are sorted, which the posFIndexGTE-based channel
         * lookup also relies on), and only a strictly larger value replaces a window's maximum.
         * Ties therefore resolve to the first cell in row-major order. An empty row, or a null
         * sparse block, is treated as all zeros.
         */
        protected void maxpoolingForward(double[] maxVal, int[] maxIx, int n, int c, int padh, int padw, int strideh, int stridew, int C, int P, int Q, int R, int S, int HW, int W) {
            SparseBlock sblock = this._params.input1.getSparseBlock();
            // A null sparse block (presumably an unallocated, all-zero input) takes the same path
            // as an empty row instead of dereferencing null.
            if (sblock != null && !sblock.isEmpty(n)) {
                Arrays.fill(maxVal, -Double.MAX_VALUE);
                int apos = sblock.pos(n);
                int alen = sblock.size(n);
                int[] aix = sblock.indexes(n);
                double[] avals = sblock.values(n);
                // Locate channel c's column range [c*HW, (c+1)*HW) within the row; negative
                // posFIndexGTE results are mapped to the end of the row.
                int cpos = c == 0 ? 0 : sblock.posFIndexGTE(n, c * HW);
                int cpos2 = c + 1 == C ? alen : sblock.posFIndexGTE(n, (c + 1) * HW);
                cpos = cpos >= 0 ? cpos : alen;
                cpos2 = cpos2 >= 0 ? cpos2 : alen;
                int lastix = c * HW - 1;
                for (int j = apos + cpos; j < apos + cpos2; ++j) {
                    // Offer the implicit zeros skipped since the previous non-zero.
                    update0(lastix + 1, aix[j], maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
                    // Offer the current non-zero.
                    int h = aix[j] % HW / W;
                    int w = aix[j] % W;
                    double val = this.reluBack && avals[j] < 0.0 ? 0.0 : avals[j];
                    update(val, maxVal, maxIx, h, w, padh, padw, strideh, stridew, P, Q, R, S, W);
                    lastix = aix[j];
                }
                // Offer the trailing implicit zeros of the channel.
                update0(lastix + 1, (c + 1) * HW, maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
            } else {
                // All zeros: every window's maximum is 0 at its top-left cell, clamped to row and
                // column 0.
                Arrays.fill(maxVal, 0.0);
                int ix = 0;
                for (int p = 0; p < P; ++p) {
                    int h = Math.max(-padh + p * strideh, 0);
                    for (int q = 0; q < Q; ++q, ++ix) {
                        int w = Math.max(-padw + q * stridew, 0);
                        maxIx[ix] = h * W + w;
                    }
                }
            }
        }

        /**
         * Adds the dense dout values of row {@code n}, channel {@code c} (starting at
         * {@code n*CPQ + c*PQ}) to the output cells {@code outOffset + maxIx[pq]}.
         */
        protected void maxpoolingBackward(int[] maxIx, int outOffset, int n, int c, int C, int Q, int PQ, int CPQ) {
            double[] dout = this.doutput.getDenseBlock();
            double[] out = this.output.getDenseBlock();
            int doutOffset = n * CPQ + c * PQ;
            for (int pq = 0; pq < PQ; ++pq) {
                out[outOffset + maxIx[pq]] += dout[doutOffset + pq];
            }
        }

        /**
         * Offers the value 0 (an implicit zero) at every column index in {@code [lix, uix)}, where
         * a column index decodes to {@code h = i % HW / W} and {@code w = i % W}.
         */
        private static void update0(int lix, int uix, double[] maxVal, int[] maxIx, int padh, int padw, int strideh, int stridew, int P, int Q, int R, int S, int HW, int W) {
            for (int i = lix; i < uix; ++i) {
                update(0.0, maxVal, maxIx, i % HW / W, i % W, padh, padw, strideh, stridew, P, Q, R, S, W);
            }
        }

        /**
         * Offers {@code val} at input cell {@code (h, w)} to every pooling window that covers it.
         * A window's maximum is replaced only when {@code val} is strictly larger.
         */
        private static void update(double val, double[] maxVal, int[] maxIx, int h, int w, int padh, int padw, int strideh, int stridew, int P, int Q, int R, int S, int W) {
            // Window p covers input rows [p*strideh - padh, p*strideh - padh + R); these bounds
            // select the p (and analogously q) whose window contains h (w), clamped to [0, P)
            // and [0, Q).
            int lp = Math.max((h + padh - R + strideh) / strideh, 0);
            int up = Math.min((h + padh + strideh) / strideh, P);
            int lq = Math.max((w + padw - S + stridew) / stridew, 0);
            int uq = Math.min((w + padw + stridew) / stridew, Q);
            int maxIndex = h * W + w;
            for (int p = lp; p < up; ++p) {
                for (int q = lq; q < uq; ++q) {
                    int ix = p * Q + q;
                    if (maxVal[ix] < val) {
                        maxVal[ix] = val;
                        maxIx[ix] = maxIndex;
                    }
                }
            }
        }
    }

    /**
     * Max-pooling backward task for a dense pooling input ({@code params.input1}) and a sparse
     * output gradient ({@code params.input2}); gradients are accumulated into the dense block of
     * {@code params.output}.
     *
     * <p>Only the non-zero dout entries are visited. For each one, the input position of the
     * corresponding window maximum is looked up with {@code LibMatrixDNNHelper.getMaxIndex}, and
     * the dout value is added there.
     */
    public static class PoolingBackwardDenseSparse
    implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;
        MatrixBlock output;
        boolean performReluBackward;
        double[] inputArray;
        MatrixBlock dout;
        int C;
        int CHW;
        int P;
        int Q;
        int HW;

        /**
         * @param rl first row to process (inclusive)
         * @param ru last row to process (exclusive)
         * @param params convolution parameters holding input1 (dense pooling input), input2
         *     (sparse dout) and output
         * @param performReluBackward passed through to {@code LibMatrixDNNHelper.getMaxIndex}
         * @throws RuntimeException if input1 or output has no dense block, or if input2 is not in
         *     sparse format
         */
        public PoolingBackwardDenseSparse(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
            this.performReluBackward = performReluBackward;
            this.inputArray = params.input1.getDenseBlock();
            this.dout = params.input2;
            this.output = params.output;
            this.C = params.C;
            this.CHW = params.C * params.H * params.W;
            this.HW = params.H * params.W;
            this.P = params.P;
            this.Q = params.Q;
            if (this.inputArray == null || this.output.getDenseBlock() == null) {
                throw new RuntimeException("Incorrect usage: empty inputs");
            }
            if (!params.input2.isInSparseFormat()) {
                throw new RuntimeException("Incorrect usage: Call optimized versions");
            }
        }

        /**
         * Scatters the non-zeros of dout rows {@code [rl, ru)} into the output.
         *
         * @return the result of {@code output.recomputeNonZeros(rl, ru - 1)} (presumably the
         *     non-zero count of the processed output rows)
         */
        @Override
        public Long call() throws Exception {
            double[] out = this.output.getDenseBlock();
            // A null sparse block (presumably an unallocated, all-zero dout) contributes nothing;
            // guard it instead of dereferencing null.
            if (this.dout.sparseBlock != null) {
                // Output buffer of computeTensorIndexes; loop-invariant, so allocated once per task.
                int[] tensorIndexes = new int[3];
                for (int n = this._rl; n < this._ru; ++n) {
                    if (this.dout.sparseBlock.isEmpty(n)) {
                        continue;
                    }
                    int apos = this.dout.sparseBlock.pos(n);
                    int alen = this.dout.sparseBlock.size(n);
                    int[] aix = this.dout.sparseBlock.indexes(n);
                    double[] avals = this.dout.sparseBlock.values(n);
                    for (int j = apos; j < apos + alen; ++j) {
                        // Split the dout column index into channel c and window coordinates (p, q).
                        LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, this.P, this.Q);
                        int c = tensorIndexes[0];
                        int p = tensorIndexes[1];
                        int q = tensorIndexes[2];
                        int inputOffset = n * this.CHW + c * this.HW;
                        int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, this.inputArray, this._params, this.performReluBackward);
                        // A result of -1 (presumably "no maximum found") is skipped.
                        if (maxIndex != -1) {
                            out[maxIndex] += avals[j];
                        }
                    }
                }
            }
            return this.output.recomputeNonZeros(this._rl, this._ru - 1);
        }
    }

    /**
     * Max-pooling backward task for a dense pooling input ({@code params.input1}) and a dense
     * output gradient ({@code params.input2}); gradients are accumulated into the dense block of
     * {@code params.output}.
     *
     * <p>Every (row, channel, p, q) window of rows {@code [rl, ru)} is visited. The window's dout
     * value is added at the input position returned by {@code LibMatrixDNNHelper.getMaxIndex};
     * windows for which that lookup yields -1 are skipped.
     */
    public static class PoolingBackwardDenseDense
    implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;
        boolean performReluBackward;
        double[] inputArray;
        double[] doutArray;
        MatrixBlock output;
        int C;
        int CHW;
        int P;
        int Q;
        int HW;
        int CPQ;
        int PQ;

        /**
         * @param rl first row to process (inclusive)
         * @param ru last row to process (exclusive)
         * @param params convolution parameters holding input1 (dense pooling input), input2
         *     (dense dout) and output
         * @param performReluBackward passed through to {@code LibMatrixDNNHelper.getMaxIndex}
         * @throws RuntimeException if input1, input2 or output has no dense block
         */
        public PoolingBackwardDenseDense(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
            _rl = rl;
            _ru = ru;
            _params = params;
            this.performReluBackward = performReluBackward;
            inputArray = params.input1.getDenseBlock();
            doutArray = params.input2.getDenseBlock();
            output = params.output;
            // Cached tensor geometry (products are identical under modular int arithmetic).
            C = params.C;
            HW = params.H * params.W;
            CHW = C * HW;
            P = params.P;
            Q = params.Q;
            PQ = P * Q;
            CPQ = C * PQ;
            boolean anyBlockMissing = inputArray == null
                || doutArray == null
                || output.getDenseBlock() == null;
            if (anyBlockMissing) {
                throw new RuntimeException("Incorrect usage: empty inputs");
            }
        }

        /**
         * Routes each dout value of rows {@code [rl, ru)} to its window's arg-max input cell.
         *
         * @return the result of {@code output.recomputeNonZeros(rl, ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final double[] grad = output.getDenseBlock();
            for (int row = _rl; row < _ru; row++) {
                for (int ch = 0; ch < C; ch++) {
                    final int inBase = row * CHW + ch * HW;
                    // Walks the channel's P*Q dout cells in (p, q) row-major order.
                    int doutPos = row * CPQ + ch * PQ;
                    for (int p = 0; p < P; p++) {
                        for (int q = 0; q < Q; q++, doutPos++) {
                            final int argMax = LibMatrixDNNHelper.getMaxIndex(p, q, inBase, inputArray, _params, performReluBackward);
                            if (argMax != -1) {
                                grad[argMax] += doutArray[doutPos];
                            }
                        }
                    }
                }
            }
            return output.recomputeNonZeros(_rl, _ru - 1);
        }
    }
}

