/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.util.ConvolutionUtils;
import org.apache.sysml.utils.NativeHelper;

/**
 * CP instruction implementing the convolution / DNN operator family.
 *
 * <p>Supported opcodes (matched case-insensitively):
 * <ul>
 *   <li>maxpooling, relu_maxpooling, avgpooling</li>
 *   <li>maxpooling_backward, relu_maxpooling_backward, avgpooling_backward</li>
 *   <li>conv2d, conv2d_bias_add, conv2d_backward_filter, conv2d_backward_data</li>
 *   <li>bias_add, bias_multiply, relu_backward, channel_sums</li>
 * </ul>
 */
public class ConvolutionCPInstruction extends UnaryCPInstruction {
    private static final Log LOG = LogFactory.getLog(ConvolutionCPInstruction.class.getName());

    /** Set once the "CPU under-utilization" warning has been logged, so it is only emitted once. */
    private static boolean warnedUnderUtilization = false;

    /** Second operand: filter/dout/bias matrix, or the C scalar for channel_sums. */
    private final CPOperand _in2;
    /** Third operand: filter for conv2d_bias_add, or the HW scalar for channel_sums. */
    private final CPOperand _in3;
    /** Scalars [N, C, H, W]. */
    private final ArrayList<CPOperand> _input_shape;
    /** Scalars [K, C, R, S]; index 1 is not read by this class. */
    private final ArrayList<CPOperand> _filter_shape;
    /** Scalars [stride_h, stride_w]. */
    private final ArrayList<CPOperand> _stride;
    /** Scalars [pad_h, pad_w]. */
    private final ArrayList<CPOperand> _padding;
    private final int _numThreads;
    /** Memory budget used by {@link #resetNumThreads} to cap parallelism of conv2d intermediates. */
    private final double _intermediateMemoryBudget;

    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget, String opcode, String istr) {
        super(CPInstruction.CPType.Convolution, null, in, out, opcode, istr);
        this._in2 = in2;
        this._in3 = in3;
        this._stride = stride;
        this._padding = padding;
        this._input_shape = input_shape;
        this._filter_shape = filter_shape;
        this._numThreads = numThreads;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Constructor for the two-input element-wise ops (bias_add, bias_multiply, relu_backward).
     *
     * @throws DMLRuntimeException if {@code opcode} is not one of those three
     */
    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) throws DMLRuntimeException {
        this(in, in2, null, out, null, null, null, null, numThreads, intermediateMemoryBudget, opcode, istr);
        if (!(opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply"))) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode);
        }
    }

    /**
     * Constructor for channel_sums (matrix input plus C and HW scalar inputs).
     *
     * @throws DMLRuntimeException if {@code opcode} is not channel_sums
     */
    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) throws DMLRuntimeException {
        this(in, in2, in3, out, null, null, null, null, numThreads, intermediateMemoryBudget, opcode, istr);
        if (!opcode.equalsIgnoreCase("channel_sums")) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
        }
    }

    /** Constructor for the single-matrix-input pooling ops. */
    private ConvolutionCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, null, null, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /** Constructor for the two-matrix-input ops (conv2d, its backward passes, pooling backward). */
    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, in2, null, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /** Constructor for the three-matrix-input op conv2d_bias_add. */
    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, in2, in3, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /**
     * Parses an instruction string into a {@code ConvolutionCPInstruction}.
     *
     * <p>For the convolution/pooling opcodes, the matrix inputs are followed by 12 scalar operands:
     * stride(2), padding(2), input_shape(4) and filter_shape(4). Then come the output, the thread
     * count and the intermediate memory budget.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is unknown or the field count is wrong
     */
    public static ConvolutionCPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
            InstructionUtils.checkNumFields(parts, 16);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand out = new CPOperand(parts[14]);
            ArrayList<CPOperand> stride = parseOperands(parts, 2, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 4, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 6, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 10, 4);
            int k = Integer.parseInt(parts[15]);
            return new ConvolutionCPInstruction(in, out, opcode, str, stride, padding, input_shape, filter_shape, k, Double.parseDouble(parts[16]));
        }
        if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("relu_maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward") || opcode.equalsIgnoreCase("conv2d") || opcode.equalsIgnoreCase("conv2d_backward_filter") || opcode.equalsIgnoreCase("conv2d_backward_data")) {
            InstructionUtils.checkNumFields(parts, 17);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[15]);
            ArrayList<CPOperand> stride = parseOperands(parts, 3, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 5, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 7, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 11, 4);
            int k = Integer.parseInt(parts[16]);
            return new ConvolutionCPInstruction(in, in2, out, opcode, str, stride, padding, input_shape, filter_shape, k, Double.parseDouble(parts[17]));
        }
        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            InstructionUtils.checkNumFields(parts, 18);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[16]);
            ArrayList<CPOperand> stride = parseOperands(parts, 4, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 6, 2);
            ArrayList<CPOperand> input_shape = parseOperands(parts, 8, 4);
            ArrayList<CPOperand> filter_shape = parseOperands(parts, 12, 4);
            int k = Integer.parseInt(parts[17]);
            return new ConvolutionCPInstruction(in, in2, in3, out, opcode, str, stride, padding, input_shape, filter_shape, k, Double.parseDouble(parts[18]));
        }
        // Case-insensitive for all three opcodes, consistent with processInstruction.
        if (opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")) {
            InstructionUtils.checkNumFields(parts, 5);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[3]);
            int k = Integer.parseInt(parts[4]);
            return new ConvolutionCPInstruction(in, in2, out, opcode, str, k, Double.parseDouble(parts[5]));
        }
        if (opcode.equalsIgnoreCase("channel_sums")) {
            InstructionUtils.checkNumFields(parts, 4);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[4]);
            return new ConvolutionCPInstruction(in, in2, in3, out, opcode, str, -1, 0.0);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
    }

    /** Builds {@code count} operands from consecutive instruction fields starting at {@code start}. */
    private static ArrayList<CPOperand> parseOperands(String[] parts, int start, int count) {
        ArrayList<CPOperand> operands = new ArrayList<CPOperand>(count);
        for (int i = start; i < start + count; i++) {
            operands.add(new CPOperand(parts[i]));
        }
        return operands;
    }

    /** Reads the scalar operand at {@code index} of {@code aL} and narrows its long value to int. */
    private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) throws DMLRuntimeException {
        CPOperand op = aL.get(index);
        return (int) ec.getScalarInput(op.getName(), op.getValueType(), op.isLiteral()).getLongValue();
    }

    /**
     * Executes relu_backward on input1 (input) and _in2 (dout).
     *
     * <p>The output is left unallocated (all zeros) when either input is empty. Otherwise it is
     * filled by {@code LibMatrixDNN.reluBackward}.
     */
    public void processReluBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        MatrixBlock outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), input.isInSparseFormat() || dout.isInSparseFormat());
        if (!input.isEmpty() && !dout.isEmpty()) {
            outputBlock.allocateBlock();
            LibMatrixDNN.reluBackward(input, dout, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Executes bias_add of the single-column matrix _in2 onto input1.
     *
     * @throws DMLRuntimeException if the bias matrix does not have exactly one column
     */
    public void processBiasAddInstruction(ExecutionContext ec) throws DMLRuntimeException {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        MatrixBlock bias = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        MatrixBlock outputBlock;
        if (bias.getNumColumns() != 1) {
            throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + bias.getNumColumns());
        }
        if (input.isEmpty() && bias.isEmpty()) {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
        } else if (bias.isEmpty()) {
            // Adding a zero bias: the result is a copy of the input.
            outputBlock = new MatrixBlock(input);
        } else {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false);
            outputBlock.allocateDenseBlock();
            LibMatrixDNN.biasAdd(input, bias, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Executes bias_multiply of input1 by the single-column matrix _in2.
     *
     * @throws DMLRuntimeException if the bias matrix does not have exactly one column
     */
    public void processBiasMultiplyInstruction(ExecutionContext ec) throws DMLRuntimeException {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        MatrixBlock bias = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        MatrixBlock outputBlock;
        if (bias.getNumColumns() != 1) {
            throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + bias.getNumColumns());
        }
        if (bias.isEmpty()) {
            // Multiplying by an all-zero bias yields an empty result.
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true);
        } else {
            outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), input.isInSparseFormat()).allocateBlock();
            LibMatrixDNN.biasMultiply(input, bias, outputBlock, this._numThreads);
        }
        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Executes channel_sums: a C x 1 column whose entry c sums, over all rows, the input columns
     * [c*HW, (c+1)*HW).
     *
     * <p>C is read from _in2 and HW from _in3. The dense path uses Kahan summation; the sparse
     * path uses plain addition.
     *
     * @throws DMLRuntimeException if C*HW differs from the number of input columns
     */
    public void processChannelSumsInstruction(ExecutionContext ec) throws DMLRuntimeException {
        MatrixBlock input = ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        int C = (int) ec.getScalarInput(this._in2.getName(), this._in2.getValueType(), this._in2.isLiteral()).getLongValue();
        int HW = (int) ec.getScalarInput(this._in3.getName(), this._in3.getValueType(), this._in3.isLiteral()).getLongValue();
        // Widen before multiplying so a large C*HW cannot overflow int and falsely pass the check.
        if ((long) C * HW != input.getNumColumns()) {
            throw new DMLRuntimeException("Expected C*HW (" + C + "*" + HW + ") to be equal to number of columns of input " + input.getNumColumns());
        }
        MatrixBlock outputBlock;
        if (input.isEmpty()) {
            outputBlock = new MatrixBlock(C, 1, true);
        } else {
            outputBlock = new MatrixBlock(C, 1, false).allocateBlock();
            double[] output = outputBlock.getDenseBlockValues();
            if (input.isInSparseFormat()) {
                SparseBlock sblock = input.getSparseBlock();
                for (int n = 0; n < input.getNumRows(); ++n) {
                    if (sblock.isEmpty(n)) {
                        continue;
                    }
                    int apos = sblock.pos(n);
                    int alen = sblock.size(n);
                    int[] aix = sblock.indexes(n);
                    double[] avals = sblock.values(n);
                    for (int j = apos; j < apos + alen; ++j) {
                        // Column index chw belongs to channel chw / HW.
                        output[aix[j] / HW] += avals[j];
                    }
                }
            } else {
                double[] inArr = input.getDenseBlockValues();
                if (inArr != null) {
                    KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                    for (int c = 0; c < C; ++c) {
                        KahanObject sum = new KahanObject(0.0, 0.0);
                        for (int n = 0; n < input.getNumRows(); ++n) {
                            int index = n * C * HW + c * HW;
                            for (int hw = 0; hw < HW; ++hw, ++index) {
                                kplus.execute2(sum, inArr[index]);
                            }
                        }
                        output[c] = sum._sum;
                    }
                }
            }
            outputBlock.recomputeNonZeros();
        }
        ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Reports whether {@code filter} is stored sparse, after possibly converting it.
     *
     * <p>Side effect: a sparse block with fewer than 1e7 cells is converted to dense in place
     * before the check, so such blocks report {@code false}.
     */
    private static boolean isFilterSparse(MatrixBlock filter) throws DMLRuntimeException {
        // Widen before multiplying: rows*cols can overflow int for large blocks.
        long numElems = (long) filter.getNumRows() * filter.getNumColumns();
        if (filter.isInSparseFormat() && (double) numElems < 1.0E7) {
            filter.sparseToDense();
        }
        return filter.isInSparseFormat();
    }

    /**
     * Returns the fraction of non-zero cells of {@code mb} as a double in [0, 1].
     *
     * <p>Uses floating-point division (integer division would truncate to 0). Returns 1.0, the
     * conservative dense assumption, when the block has no cells or an unknown (negative)
     * non-zero count.
     */
    private static double getSparsity(MatrixBlock mb) {
        long cells = (long) mb.getNumRows() * mb.getNumColumns();
        long nnz = mb.getNonZeros();
        if (cells <= 0 || nnz < 0) {
            return 1.0;
        }
        return Math.min(1.0, (double) nnz / cells);
    }

    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        if (this.instOpcode.equalsIgnoreCase("bias_add")) {
            this.processBiasAddInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("bias_multiply")) {
            this.processBiasMultiplyInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("relu_backward")) {
            this.processReluBackwardInstruction(ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("channel_sums")) {
            this.processChannelSumsInstruction(ec);
            return;
        }
        // For avgpooling_backward the first input is neither acquired nor released; null is passed
        // to poolingBackward in its place.
        boolean acquiresInput1 = !this.instOpcode.equalsIgnoreCase("avgpooling_backward");
        MatrixBlock matBlock = acquiresInput1 ? ec.getMatrixInput(this.input1.getName(), this.getExtendedOpcode()) : null;
        int pad_h = getScalarInput(ec, this._padding, 0);
        int pad_w = getScalarInput(ec, this._padding, 1);
        int stride_h = getScalarInput(ec, this._stride, 0);
        int stride_w = getScalarInput(ec, this._stride, 1);
        int N = getScalarInput(ec, this._input_shape, 0);
        int C = getScalarInput(ec, this._input_shape, 1);
        int H = getScalarInput(ec, this._input_shape, 2);
        int W = getScalarInput(ec, this._input_shape, 3);
        int K = getScalarInput(ec, this._filter_shape, 0);
        int R = getScalarInput(ec, this._filter_shape, 2);
        int S = getScalarInput(ec, this._filter_shape, 3);
        int P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h);
        int Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w);
        ConvolutionParameters params = new ConvolutionParameters(N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, this._numThreads);
        params.enableNative = NativeHelper.isNativeLibraryLoaded();

        MatrixBlock outputBlock;
        if (this.instOpcode.equalsIgnoreCase("maxpooling") || this.instOpcode.equalsIgnoreCase("relu_maxpooling") || this.instOpcode.equalsIgnoreCase("avgpooling")) {
            if (matBlock.isEmpty()) {
                outputBlock = new MatrixBlock(N, C * P * Q, true);
            } else {
                outputBlock = new MatrixBlock(N, C * P * Q, false).allocateBlock();
                boolean isReluMax = this.instOpcode.equalsIgnoreCase("relu_maxpooling");
                LibMatrixDNN.PoolingType poolType = (isReluMax || this.instOpcode.equalsIgnoreCase("maxpooling")) ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG;
                if (isReluMax) {
                    params.minValForMaxPoolOperations = 0.0;
                }
                LibMatrixDNN.pooling(matBlock, outputBlock, params, poolType);
            }
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling_backward") || this.instOpcode.equalsIgnoreCase("relu_maxpooling_backward") || this.instOpcode.equalsIgnoreCase("avgpooling_backward")) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            // avgpooling_backward has no first input (matBlock == null), so only dout decides emptiness.
            boolean isEmpty = this.instOpcode.equalsIgnoreCase("avgpooling_backward") ? dout.isEmpty() : (matBlock.isEmpty() || dout.isEmpty());
            if (isEmpty) {
                outputBlock = new MatrixBlock(N, C * H * W, true);
            } else {
                outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock();
                boolean performReLUBackward = this.instOpcode.equalsIgnoreCase("relu_maxpooling_backward");
                LibMatrixDNN.PoolingType poolType = (performReLUBackward || this.instOpcode.equalsIgnoreCase("maxpooling_backward")) ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG;
                if (performReLUBackward) {
                    params.minValForMaxPoolOperations = 0.0;
                }
                LibMatrixDNN.poolingBackward(matBlock, dout, outputBlock, params, performReLUBackward, poolType);
            }
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d")) {
            this.resetNumThreads(params, C * R * S, P * Q, getSparsity(matBlock));
            MatrixBlock filter = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            if (filter.isEmpty() || matBlock.isEmpty()) {
                outputBlock = new MatrixBlock(N, K * P * Q, true);
            } else {
                boolean sparse = matBlock.isUltraSparse(false) && params.bias == null && matBlock.getInMemorySize() < MatrixBlock.estimateSizeDenseInMemory(N, K * P * Q);
                outputBlock = new MatrixBlock(N, K * P * Q, sparse).allocateBlock();
                if (params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat()) {
                    LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
                } else {
                    LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
                }
            }
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
            this.resetNumThreads(params, C * R * S, P * Q, getSparsity(matBlock));
            // Operand order for conv2d_bias_add: input1 = image, _in2 = bias, _in3 = filter.
            MatrixBlock filter = ec.getMatrixInput(this._in3.getName(), this.getExtendedOpcode());
            MatrixBlock bias = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            if (bias.getNumRows() != params.K || bias.getNumColumns() != 1) {
                throw new DMLRuntimeException("Incorrect shape of bias matrix: [" + bias.getNumRows() + " " + bias.getNumColumns() + "]. Expected: [" + params.K + ", 1]");
            }
            boolean isOutputConvEmpty = filter.isEmpty() || matBlock.isEmpty();
            if (isOutputConvEmpty && bias.isEmpty()) {
                outputBlock = new MatrixBlock(N, K * P * Q, true);
            } else if (isOutputConvEmpty) {
                // The convolution contributes nothing, so the output is just the broadcast bias.
                outputBlock = new MatrixBlock(N, K * P * Q, false).allocateBlock();
                for (int n = 0; n < params.N; ++n) {
                    ConvolutionUtils.fillBias(bias, outputBlock.getDenseBlockValues(), n, n + 1, params.N, params.K, params.P * params.Q);
                }
            } else {
                outputBlock = new MatrixBlock(N, K * P * Q, false).allocateBlock();
                if (!bias.isEmpty()) {
                    params.bias = bias;
                }
                if (params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat()) {
                    LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
                } else {
                    LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
                }
            }
            ec.releaseMatrixInput(this._in3.getName(), this.getExtendedOpcode());
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            if (dout.isEmpty() || matBlock.isEmpty()) {
                outputBlock = new MatrixBlock(K, C * R * S, true);
            } else {
                outputBlock = new MatrixBlock(K, C * R * S, false).allocateBlock();
                if (params.enableNative && !matBlock.isInSparseFormat() && !dout.isInSparseFormat()) {
                    LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
                } else {
                    LibMatrixDNN.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
                }
            }
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
            MatrixBlock dout = ec.getMatrixInput(this._in2.getName(), this.getExtendedOpcode());
            if (dout.isEmpty() || matBlock.isEmpty()) {
                outputBlock = new MatrixBlock(N, C * H * W, true);
            } else {
                outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock();
                // Here the first input plays the filter role, so its sparsity decides native eligibility.
                if (params.enableNative && !isFilterSparse(matBlock) && !dout.isInSparseFormat()) {
                    LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params);
                } else {
                    LibMatrixDNN.conv2dBackwardData(matBlock, dout, outputBlock, params);
                }
            }
            ec.releaseMatrixInput(this._in2.getName(), this.getExtendedOpcode());
        } else {
            throw new DMLRuntimeException("Unsupported op code " + this.instOpcode);
        }
        if (acquiresInput1) {
            ec.releaseMatrixInput(this.input1.getName(), this.getExtendedOpcode());
        }
        ec.setMatrixOutput(this.getOutputVariableName(), outputBlock, this.getExtendedOpcode());
    }

    /**
     * Lowers {@code params.numThreads} so that per-thread intermediates of size
     * numRows x numCols (at the given sparsity) fit into the intermediate memory budget.
     *
     * <p>This only applies when {@code DMLScript.USE_ACCELERATOR} is set. At least one thread is
     * always kept, even when a single thread's estimate exceeds the budget. The warning is logged
     * only once; the static flag is not synchronized, so under concurrency it may be logged more
     * than once.
     */
    private void resetNumThreads(ConvolutionParameters params, int numRows, int numCols, double sparsity) {
        if (!DMLScript.USE_ACCELERATOR) {
            return;
        }
        double memBudget1Thread = (double) OptimizerUtils.estimateSizeExactSparsity((long) numRows, (long) numCols, sparsity);
        // Clamp to 1: a budget smaller than one thread's estimate must not yield 0 threads.
        int limitedDegreeOfParallelism = Math.max(1, (int) Math.floor(this._intermediateMemoryBudget / memBudget1Thread));
        if (params.numThreads > limitedDegreeOfParallelism) {
            params.numThreads = limitedDegreeOfParallelism;
            if (!warnedUnderUtilization) {
                LOG.warn("CPU Under-utilization to respect the intermediate memory budget. To avoid this, please try reducing the mini-batch or forcing gpu execution.");
                warnedUnderUtilization = true;
            }
        }
    }
}

