/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.util;

import java.util.Arrays;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.ScalarOperator;

/**
 * Static helpers for the convolution/pooling runtime.
 *
 * <p>It provides:
 * <ul>
 *   <li>output-size computation ({@code P}, {@code Q});</li>
 *   <li>in-place element-wise kernels that apply rows of a {@link MatrixBlock} to a flat
 *       {@code double[]} buffer. Source row {@code src_rl} lands at {@code dest[destPos]}, and each
 *       following row starts {@code destNumCols} elements later.</li>
 * </ul>
 */
public class ConvolutionUtils {

    /**
     * Builds the conv2d output height {@code P = (H + 2*pad - R) / stride + 1}.
     *
     * <p>If every argument parses as a {@code long} and {@link #getP} accepts the values, the
     * numeric result is returned as a decimal string. Otherwise a parenthesized symbolic
     * expression over the input strings is returned. When only the padding is numeric, its
     * doubled value is folded into that expression, and it is omitted entirely when it is zero.
     *
     * @param H input height (number or symbol)
     * @param R filter height (number or symbol)
     * @param verticalStride vertical stride (number or symbol)
     * @param heightPadding height padding (number or symbol)
     * @return the numeric output height, or a symbolic expression for it
     */
    public static String getConv2dOutputMap(String H, String R, String verticalStride, String heightPadding) {
        long padX2;
        try {
            padX2 = Long.parseLong(heightPadding) * 2L;
        }
        catch (NumberFormatException e) {
            // Padding is not a literal number (or is null): keep it symbolic.
            return "((" + H + " + 2*" + heightPadding + " - " + R + ") / " + verticalStride + "+ 1)";
        }
        try {
            return String.valueOf(ConvolutionUtils.getP(Long.parseLong(H), Long.parseLong(R),
                    Long.parseLong(verticalStride), Long.parseLong(heightPadding)));
        }
        catch (RuntimeException e) {
            // Another dimension is symbolic, or getP rejected the values: emit an expression
            // with the known padding folded in.
            if (padX2 == 0L) {
                return "((" + H + " - " + R + ") / " + verticalStride + "+ 1)";
            }
            return "((" + H + " + " + padX2 + " - " + R + ") / " + verticalStride + "+ 1)";
        }
    }

    /**
     * Computes the output height {@code P = (H + 2*heightPadding - R) / verticalStride + 1}.
     *
     * @param H input height, must be positive
     * @param R filter height, must be positive
     * @param verticalStride vertical stride, must be positive
     * @param heightPadding height padding, must be non-negative
     * @return P, guaranteed to lie in {@code [1, Integer.MAX_VALUE]}
     * @throws RuntimeException if a parameter is out of range, or the resulting P is not positive
     *     or does not fit in an {@code int}
     */
    public static long getP(long H, long R, long verticalStride, long heightPadding) {
        // A zero stride must be rejected here; otherwise the division below throws "/ by zero".
        if (H <= 0L || R <= 0L || heightPadding < 0L || verticalStride <= 0L) {
            throw new RuntimeException("Incorrect parameters: height=" + H + " filter_height=" + R + " stride=" + verticalStride + " pad=" + heightPadding);
        }
        long padded_image_height = H + 2L * heightPadding;
        long ret = (padded_image_height - R) / verticalStride + 1L;
        if (ret <= 0L || ret > Integer.MAX_VALUE) {
            if (padded_image_height < R) {
                throw new RuntimeException("Incorrect parameters: padded image height:" + padded_image_height + " cannot be less than filter_height:" + R);
            }
            throw new RuntimeException("Incorrect parameters: height=" + H + " filter_height=" + R + " stride=" + verticalStride + " pad=" + heightPadding + " as P=" + ret);
        }
        return ret;
    }

    /**
     * Computes the output width {@code Q = (W + 2*widthPadding - S) / horizontalStride + 1}.
     *
     * @param W input width, must be positive
     * @param S filter width, must be positive
     * @param horizontalStride horizontal stride, must be positive
     * @param widthPadding width padding, must be non-negative
     * @return Q, guaranteed to lie in {@code [1, Integer.MAX_VALUE]}
     * @throws RuntimeException if a parameter is out of range, or the resulting Q is not positive
     *     or does not fit in an {@code int}
     */
    public static long getQ(long W, long S, long horizontalStride, long widthPadding) {
        // A zero stride must be rejected here; otherwise the division below throws "/ by zero".
        if (W <= 0L || S <= 0L || widthPadding < 0L || horizontalStride <= 0L) {
            throw new RuntimeException("Incorrect parameters: width=" + W + " filter_width=" + S + " stride=" + horizontalStride + " pad=" + widthPadding);
        }
        long padded_image_width = W + 2L * widthPadding;
        long ret = (padded_image_width - S) / horizontalStride + 1L;
        if (ret <= 0L || ret > Integer.MAX_VALUE) {
            if (padded_image_width < S) {
                throw new RuntimeException("Incorrect parameters: padded image width:" + padded_image_width + " cannot be less than filter width:" + S);
            }
            throw new RuntimeException("Incorrect parameters: width=" + W + " filter_width=" + S + " stride=" + horizontalStride + " pad=" + widthPadding + " as Q=" + ret);
        }
        return ret;
    }

    /**
     * Applies {@code dest[r, c] = op(dest[r, c], src[r, c])} in place for source rows
     * {@code [src_rl, src_ru)}.
     *
     * <p>Sparse sources support only {@link Plus} and {@link Multiply}. Dense sources support any
     * binary function. Source rows are assumed to have {@code destNumCols} columns (TODO confirm
     * against callers). Sparse column indexes within a row are assumed to be sorted ascending; the
     * multiply path relies on this to zero the gaps between non-zeros.
     *
     * @throws DMLRuntimeException if {@code src} is sparse and {@code op} is neither plus nor
     *     multiply, or if {@code op} itself fails
     */
    public static void binaryOperationInPlace(MatrixBlock src, double[] dest, int destPos, int destNumCols, int src_rl, int src_ru, BinaryOperator op) throws DMLRuntimeException {
        final boolean isPlus = op.fn == Plus.getPlusFnObject();
        final boolean isMultiply = op.fn == Multiply.getMultiplyFnObject();
        final int len = (src_ru - src_rl) * destNumCols;

        if (src.isInSparseFormat()) {
            if ((isPlus || isMultiply) && src.isEmptyBlock()) {
                if (isMultiply) {
                    // Multiplying by an all-zero block zeroes the destination range.
                    Arrays.fill(dest, destPos, destPos + len, 0.0);
                }
                // Adding an all-zero block is a no-op.
                return;
            }
            if (isPlus) {
                for (int i = src_rl, cix = destPos; i < src_ru; ++i, cix += destNumCols) {
                    if (src.getSparseBlock().isEmpty(i)) {
                        continue;
                    }
                    int apos = src.getSparseBlock().pos(i);
                    int alen = src.getSparseBlock().size(i);
                    int[] aix = src.getSparseBlock().indexes(i);
                    double[] avals = src.getSparseBlock().values(i);
                    for (int j = apos; j < apos + alen; ++j) {
                        dest[cix + aix[j]] += avals[j];
                    }
                }
            }
            else if (isMultiply) {
                for (int i = src_rl, cix = destPos; i < src_ru; ++i, cix += destNumCols) {
                    if (src.getSparseBlock().isEmpty(i)) {
                        Arrays.fill(dest, cix, cix + destNumCols, 0.0);
                        continue;
                    }
                    int apos = src.getSparseBlock().pos(i);
                    int alen = src.getSparseBlock().size(i);
                    int[] aix = src.getSparseBlock().indexes(i);
                    double[] avals = src.getSparseBlock().values(i);
                    int prevDestIndex = 0;
                    for (int j = apos; j < apos + alen; ++j) {
                        // Cells between two stored non-zeros are implicit zeros: multiply by zero.
                        Arrays.fill(dest, cix + prevDestIndex, cix + aix[j], 0.0);
                        prevDestIndex = aix[j] + 1;
                        dest[cix + aix[j]] *= avals[j];
                    }
                    Arrays.fill(dest, cix + prevDestIndex, cix + destNumCols, 0.0);
                }
            }
            else {
                throw new DMLRuntimeException("Unimplemented sparse operation");
            }
            return;
        }

        double[] inputArr = src.getDenseBlock();
        // Offset of source row src_rl. This mirrors the sparse path's row mapping, so destPos need
        // not equal src_rl * destNumCols.
        final int srcPos = src_rl * destNumCols;
        if (inputArr == null) {
            // No dense array is allocated: treat the source as all zeros instead of hitting an NPE.
            if (isPlus) {
                return;
            }
            if (isMultiply) {
                Arrays.fill(dest, destPos, destPos + len, 0.0);
                return;
            }
            for (int k = 0; k < len; ++k) {
                dest[destPos + k] = op.fn.execute(dest[destPos + k], 0.0);
            }
            return;
        }
        if (isPlus) {
            for (int k = 0; k < len; ++k) {
                dest[destPos + k] += inputArr[srcPos + k];
            }
        }
        else if (isMultiply) {
            for (int k = 0; k < len; ++k) {
                dest[destPos + k] *= inputArr[srcPos + k];
            }
        }
        else {
            for (int k = 0; k < len; ++k) {
                dest[destPos + k] = op.fn.execute(dest[destPos + k], inputArr[srcPos + k]);
            }
        }
    }

    /**
     * Writes {@code dest[r, c] = scalarOp(src[r, c])} for source rows {@code [src_rl, src_ru)}.
     *
     * <p>Every cell of the destination range is written, including cells that are zero (absent)
     * in a sparse source; those receive {@code scalarOp(0)}. Source rows are assumed to have
     * {@code destNumCols} columns, and sparse column indexes are assumed sorted ascending (TODO
     * confirm against callers).
     *
     * @throws DMLRuntimeException if the scalar operator fails
     */
    public static void scalarOperations(MatrixBlock src, double[] dest, int destPos, int destNumCols, int src_rl, int src_ru, ScalarOperator scalarOp) throws DMLRuntimeException {
        // Value of the operator applied to an implicit zero cell.
        final double zeroVal = scalarOp.executeScalar(0.0);
        if (src.isInSparseFormat()) {
            for (int i = src_rl, cix = destPos; i < src_ru; ++i, cix += destNumCols) {
                if (src.getSparseBlock() == null || src.getSparseBlock().isEmpty(i)) {
                    Arrays.fill(dest, cix, cix + destNumCols, zeroVal);
                    continue;
                }
                int apos = src.getSparseBlock().pos(i);
                int alen = src.getSparseBlock().size(i);
                int[] aix = src.getSparseBlock().indexes(i);
                double[] avals = src.getSparseBlock().values(i);
                int prevDestIndex = 0;
                for (int j = apos; j < apos + alen; ++j) {
                    Arrays.fill(dest, cix + prevDestIndex, cix + aix[j], zeroVal);
                    prevDestIndex = aix[j] + 1;
                    dest[cix + aix[j]] = scalarOp.executeScalar(avals[j]);
                }
                Arrays.fill(dest, cix + prevDestIndex, cix + destNumCols, zeroVal);
            }
            return;
        }

        double[] inputArr = src.getDenseBlock();
        final int srcPos = src_rl * destNumCols;
        final int len = (src_ru - src_rl) * destNumCols;
        if (inputArr == null) {
            // No dense array is allocated: the whole source is zero.
            Arrays.fill(dest, destPos, destPos + len, zeroVal);
            return;
        }
        for (int k = 0; k < len; ++k) {
            dest[destPos + k] = scalarOp.executeScalar(inputArr[srcPos + k]);
        }
    }

    /**
     * Fills the output of images {@code [src_rl, src_ru)} with the per-channel bias.
     *
     * <p>Image {@code n}, channel {@code k} occupies {@code PQ} consecutive cells starting at
     * {@code n*K*PQ + k*PQ}, and every one of those cells is set to bias value {@code k}. Missing
     * bias entries count as 0. {@code bias} is presumably a K x 1 column vector (TODO confirm);
     * only the first stored value of each sparse row is read. {@code N} is not used.
     */
    public static void fillBias(MatrixBlock bias, double[] outputArray, int src_rl, int src_ru, int N, int K, int PQ) throws DMLRuntimeException {
        double[] biasVals = ConvolutionUtils.biasValues(bias, K);
        for (int n = src_rl; n < src_ru; ++n) {
            for (int k = 0; k < K; ++k) {
                int fromIndex = n * K * PQ + k * PQ;
                Arrays.fill(outputArray, fromIndex, fromIndex + PQ, biasVals[k]);
            }
        }
    }

    /** Returns the first K bias values as a dense array, with 0 for empty or unallocated entries. */
    private static double[] biasValues(MatrixBlock bias, int K) {
        if (!bias.isInSparseFormat()) {
            double[] dense = bias.getDenseBlock();
            return dense != null ? dense : new double[K];
        }
        double[] vals = new double[K];
        if (bias.getSparseBlock() == null) {
            return vals;
        }
        for (int k = 0; k < K; ++k) {
            if (!bias.getSparseBlock().isEmpty(k)) {
                vals[k] = bias.getSparseBlock().values(k)[bias.getSparseBlock().pos(k)];
            }
        }
        return vals;
    }
}

