/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu;

import java.util.ArrayList;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.ConvolutionUtils;
import org.apache.sysml.utils.GPUStatistics;

/**
 * GPU instruction covering the convolution-family opcodes handled by this class:
 * {@code conv2d}, {@code conv2d_bias_add}, {@code conv2d_backward_filter},
 * {@code conv2d_backward_data}, {@code maxpooling}, {@code maxpooling_backward},
 * {@code bias_add}, {@code bias_multiply} and {@code relu_backward}.
 *
 * <p>Instances are created through {@link #parseInstruction(String)}; execution is
 * delegated to the matching {@code LibMatrixCUDA} routine on GPU context 0.
 */
public class ConvolutionGPUInstruction extends GPUInstruction {
    private CPOperand _input1;
    private CPOperand _input2;
    private CPOperand _input3;
    private CPOperand _output;
    // Shape/stride/padding operands are resolved to scalars at execution time.
    // _input_shape and _filter_shape stay null for bias_add / bias_multiply / relu_backward.
    private ArrayList<CPOperand> _input_shape;
    private ArrayList<CPOperand> _filter_shape;
    private ArrayList<CPOperand> _stride = new ArrayList<CPOperand>();
    private ArrayList<CPOperand> _padding = new ArrayList<CPOperand>();

    /**
     * Creates a two-input instruction for {@code bias_add}, {@code bias_multiply}
     * or {@code relu_backward}.
     *
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode instruction opcode (matched case-insensitively, like the parser does)
     * @param istr   full instruction string
     * @throws DMLRuntimeException if {@code opcode} is not one of the three supported opcodes
     */
    private ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        // Case-insensitive so that every opcode accepted by parseInstruction is accepted here too.
        boolean supported = opcode.equalsIgnoreCase("bias_add")
            || opcode.equalsIgnoreCase("bias_multiply")
            || opcode.equalsIgnoreCase("relu_backward");
        if (!supported) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode);
        }
        this._input1 = in1;
        this._input2 = in2;
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Convolution;
        this._output = out;
    }

    /**
     * Creates a three-input instruction (used by the parser for {@code conv2d_bias_add}).
     *
     * @param in1          first input operand
     * @param in2          second input operand
     * @param in3          third input operand
     * @param out          output operand
     * @param opcode       instruction opcode
     * @param istr         full instruction string
     * @param stride       stride operands (two entries are read at execution time)
     * @param padding      padding operands (two entries are read at execution time)
     * @param input_shape  input shape operands (four entries are read at execution time)
     * @param filter_shape filter shape operands (entries 0, 2 and 3 are read at execution time)
     */
    private ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape) {
        this(in1, in2, out, opcode, istr, stride, padding, input_shape, filter_shape);
        this._input3 = in3;
    }

    /**
     * Creates an instruction with up to two matrix inputs plus stride, padding and
     * shape operands. {@code in2} may be {@code null} (the parser passes {@code null}
     * for {@code maxpooling}).
     *
     * @param in1          first input operand
     * @param in2          second input operand, or {@code null}
     * @param out          output operand
     * @param opcode       instruction opcode
     * @param istr         full instruction string
     * @param stride       stride operands
     * @param padding      padding operands
     * @param input_shape  input shape operands
     * @param filter_shape filter shape operands
     */
    private ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Convolution;
        this._input1 = in1;
        this._input2 = in2;
        this._output = out;
        this._stride = stride;
        this._padding = padding;
        this._input_shape = input_shape;
        this._filter_shape = filter_shape;
    }

    /**
     * Parses an instruction string into a {@code ConvolutionGPUInstruction}.
     *
     * <p>Field layout after the opcode ({@code parts[0]}):
     * <ul>
     *   <li>{@code conv2d}, {@code conv2d_backward_filter}, {@code conv2d_backward_data},
     *       {@code maxpooling_backward}: in1, in2, stride(2), padding(2), input_shape(4),
     *       filter_shape(4), out &mdash; 15 fields</li>
     *   <li>{@code conv2d_bias_add}: in1, in2, in3, stride(2), padding(2), input_shape(4),
     *       filter_shape(4), out &mdash; 16 fields</li>
     *   <li>{@code maxpooling}: in1, stride(2), padding(2), input_shape(4),
     *       filter_shape(4), out &mdash; 14 fields</li>
     *   <li>{@code bias_add}, {@code relu_backward}, {@code bias_multiply}: in1, in2, out
     *       &mdash; 3 fields</li>
     * </ul>
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is unknown or the instruction is rejected
     *                             by the field-count check
     */
    public static ConvolutionGPUInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("conv2d") || opcode.equalsIgnoreCase("conv2d_backward_filter") || opcode.equalsIgnoreCase("conv2d_backward_data") || opcode.equalsIgnoreCase("maxpooling_backward")) {
            InstructionUtils.checkNumFields(parts, 15);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[15]);
            return new ConvolutionGPUInstruction(in1, in2, out, opcode, str,
                parseOperands(parts, 3, 2), parseOperands(parts, 5, 2),
                parseOperands(parts, 7, 4), parseOperands(parts, 11, 4));
        }
        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            InstructionUtils.checkNumFields(parts, 16);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[16]);
            return new ConvolutionGPUInstruction(in1, in2, in3, out, opcode, str,
                parseOperands(parts, 4, 2), parseOperands(parts, 6, 2),
                parseOperands(parts, 8, 4), parseOperands(parts, 12, 4));
        }
        if (opcode.equalsIgnoreCase("maxpooling")) {
            InstructionUtils.checkNumFields(parts, 14);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand out = new CPOperand(parts[14]);
            // maxpooling has a single matrix input, hence the null second operand.
            return new ConvolutionGPUInstruction(in1, null, out, opcode, str,
                parseOperands(parts, 2, 2), parseOperands(parts, 4, 2),
                parseOperands(parts, 6, 4), parseOperands(parts, 10, 4));
        }
        if (opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")) {
            InstructionUtils.checkNumFields(parts, 3);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[3]);
            return new ConvolutionGPUInstruction(in1, in2, out, opcode, str);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionGPUInstruction: " + str);
    }

    /**
     * Builds operands from {@code count} consecutive instruction parts starting at {@code from}.
     *
     * @param parts instruction parts
     * @param from  index of the first part to convert
     * @param count number of parts to convert
     * @return list of operands in the same order as the parts
     * @throws DMLRuntimeException declared so any failure from operand construction propagates unchanged
     */
    private static ArrayList<CPOperand> parseOperands(String[] parts, int from, int count) throws DMLRuntimeException {
        ArrayList<CPOperand> operands = new ArrayList<CPOperand>(count);
        for (int i = from; i < from + count; i++) {
            operands.add(new CPOperand(parts[i]));
        }
        return operands;
    }

    /**
     * Executes {@code bias_add} or {@code bias_multiply}: the output is allocated with
     * the same number of rows and columns as the first input.
     *
     * @param instOpcode {@code bias_add} or {@code bias_multiply} (case-insensitive)
     * @param ec         execution context
     * @throws DMLRuntimeException if {@code instOpcode} is neither supported opcode, or if
     *                             the underlying GPU operation fails
     */
    public void processBiasInstruction(String instOpcode, ExecutionContext ec) throws DMLRuntimeException {
        boolean isAdd = "bias_add".equalsIgnoreCase(instOpcode);
        boolean isMultiply = "bias_multiply".equalsIgnoreCase(instOpcode);
        // Reject unknown opcodes up front instead of allocating an output that is never written.
        if (!isAdd && !isMultiply) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply, but found " + instOpcode);
        }
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject input = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
        MatrixObject bias = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
        MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), input.getNumRows(), input.getNumColumns());
        if (isAdd) {
            LibMatrixCUDA.biasAdd(ec.getGPUContext(0), this.getExtendedOpcode(), input, bias, out);
        } else {
            LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), this.getExtendedOpcode(), input, bias, out);
        }
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Executes {@code relu_backward}: the output is allocated with the same number of
     * rows and columns as the first input.
     *
     * @param ec execution context
     * @throws DMLRuntimeException if the underlying GPU operation fails
     */
    public void processReLUBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject input = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
        MatrixObject dout = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
        MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), input.getNumRows(), input.getNumColumns());
        LibMatrixCUDA.reluBackward(ec.getGPUContext(0), this.getExtendedOpcode(), input, dout, out);
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Executes this instruction by dispatching on the opcode.
     *
     * <p>For the convolution and pooling opcodes the scalar operands are resolved first
     * (padding, stride, input shape as N, C, H, W and filter shape entries 0, 2, 3 as
     * K, R, S), P and Q are obtained from {@code ConvolutionUtils.getP/getQ}, the input
     * matrix dimensions are validated against these values, and the matching
     * {@code LibMatrixCUDA} routine is invoked.
     *
     * @param ec execution context
     * @throws DMLRuntimeException if an input has unexpected dimensions, the opcode is not
     *                             supported, or the underlying GPU operation fails
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        if (this.instOpcode.equalsIgnoreCase("bias_add") || this.instOpcode.equalsIgnoreCase("bias_multiply")) {
            this.processBiasInstruction(this.instOpcode, ec);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("relu_backward")) {
            this.processReLUBackwardInstruction(ec);
            return;
        }
        GPUStatistics.incrementNoOfExecutedGPUInst();
        int pad_h = this.getScalarInput(ec, this._padding, 0);
        int pad_w = this.getScalarInput(ec, this._padding, 1);
        int stride_h = this.getScalarInput(ec, this._stride, 0);
        int stride_w = this.getScalarInput(ec, this._stride, 1);
        int N = this.getScalarInput(ec, this._input_shape, 0);
        int C = this.getScalarInput(ec, this._input_shape, 1);
        int H = this.getScalarInput(ec, this._input_shape, 2);
        int W = this.getScalarInput(ec, this._input_shape, 3);
        int K = this.getScalarInput(ec, this._filter_shape, 0);
        // filter_shape entry 1 is intentionally not read here; C comes from input_shape.
        int R = this.getScalarInput(ec, this._filter_shape, 2);
        int S = this.getScalarInput(ec, this._filter_shape, 3);
        int P = (int)ConvolutionUtils.getP(H, R, stride_h, pad_h);
        int Q = (int)ConvolutionUtils.getQ(W, S, stride_w, pad_w);

        // Column counts are computed in long arithmetic so that large shapes do not
        // overflow int before being compared with the (long) matrix dimensions.
        long CHW = (long)C * H * W;
        long CRS = (long)C * R * S;
        long KPQ = (long)K * P * Q;
        long CPQ = (long)C * P * Q;

        // NOTE(review): if a dimension check below throws, inputs already acquired in that
        // branch are not released, because the release calls sit at the end of this method.
        if (this.instOpcode.equalsIgnoreCase("conv2d")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject filter = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            if (!hasShape(image, N, CHW)) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
            }
            if (!hasShape(filter, K, CRS)) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, KPQ);
            LibMatrixCUDA.conv2d(ec.getGPUContext(0), this.getExtendedOpcode(), image, filter, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject bias = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            MatrixObject filter = this.getMatrixInputForGPUInstruction(ec, this._input3.getName());
            if (!hasShape(image, N, CHW)) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_bias_add");
            }
            if (!hasShape(filter, K, CRS)) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d_bias_add");
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, KPQ);
            LibMatrixCUDA.conv2dBiasAdd(ec.getGPUContext(0), this.getExtendedOpcode(), image, bias, filter, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject dout = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            if (!hasShape(image, N, CHW)) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_backward_filter");
            }
            if (!hasShape(dout, N, KPQ)) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_filter: " + dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + KPQ);
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), K, CRS);
            LibMatrixCUDA.conv2dBackwardFilter(ec.getGPUContext(0), this.getExtendedOpcode(), image, dout, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
            MatrixObject filter = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject dout = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            if (!hasShape(filter, K, CRS)) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in convolution_backward_data");
            }
            if (!hasShape(dout, N, KPQ)) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_data: " + dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + KPQ);
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, CHW);
            LibMatrixCUDA.conv2dBackwardData(ec.getGPUContext(0), this.getExtendedOpcode(), filter, dout, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            if (!hasShape(image, N, CHW)) {
                throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling: " + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + CHW);
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, CPQ);
            LibMatrixCUDA.maxpooling(ec.getGPUContext(0), this.getExtendedOpcode(), image, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling_backward")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject dout = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            if (!hasShape(dout, N, CPQ)) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward");
            }
            if (!hasShape(image, N, CHW)) {
                // Report the value the check actually compares against (C*H*W).
                throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling_backward: " + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + CHW);
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, CHW);
            LibMatrixCUDA.maxpoolingBackward(ec.getGPUContext(0), this.getExtendedOpcode(), image, dout, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        } else {
            throw new DMLRuntimeException("Unsupported GPU context for " + this.instOpcode);
        }
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        // maxpooling is parsed with a null second input, so there is nothing to release for it.
        if (!this.instOpcode.equalsIgnoreCase("maxpooling")) {
            ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        }
        if (this.instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
            ec.releaseMatrixInputForGPUInstruction(this._input3.getName());
        }
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Tells whether a matrix has exactly the given number of rows and columns.
     *
     * @param mo   matrix to inspect
     * @param rows expected row count
     * @param cols expected column count
     * @return {@code true} if both dimensions match
     * @throws DMLRuntimeException declared so any failure from the dimension accessors propagates unchanged
     */
    private static boolean hasShape(MatrixObject mo, long rows, long cols) throws DMLRuntimeException {
        return mo.getNumRows() == rows && mo.getNumColumns() == cols;
    }

    /**
     * Resolves the scalar operand at {@code index} of {@code aL} through the execution
     * context and narrows its long value to {@code int}.
     *
     * @param ec    execution context
     * @param aL    operand list
     * @param index position of the operand to resolve
     * @return the operand's long value cast to {@code int}
     * @throws DMLRuntimeException if the scalar cannot be obtained
     */
    private int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) throws DMLRuntimeException {
        CPOperand operand = aL.get(index);
        return (int)ec.getScalarInput(operand.getName(), operand.getValueType(), operand.isLiteral()).getLongValue();
    }
}

