/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu;

import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysml.runtime.matrix.data.LibMatrixCuMatMult;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.utils.GPUStatistics;

public class AggregateBinaryGPUInstruction
extends GPUInstruction {
    private CPOperand _input1 = null;
    private CPOperand _input2 = null;
    private CPOperand _output = null;
    private boolean _isLeftTransposed;
    private boolean _isRightTransposed;

    private AggregateBinaryGPUInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, boolean leftTranspose, boolean rightTranspose) {
        super(op, opcode, istr);
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.AggregateBinary;
        this._input1 = in1;
        this._input2 = in2;
        this._output = out;
        this._isLeftTransposed = leftTranspose;
        this._isRightTransposed = rightTranspose;
    }

    public static AggregateBinaryGPUInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ba+*")) {
            throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        InstructionUtils.checkNumFields(parts, 5);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        boolean isLeftTransposed = Boolean.parseBoolean(parts[4]);
        boolean isRightTransposed = Boolean.parseBoolean(parts[5]);
        AggregateOperator agg = new AggregateOperator(0.0, Plus.getPlusFnObject());
        AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg, 1);
        return new AggregateBinaryGPUInstruction(aggbin, in1, in2, out, opcode, str, isLeftTransposed, isRightTransposed);
    }

    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        AggregateBinaryOperator op = (AggregateBinaryOperator)this._optr;
        if (!(op.binaryFn instanceof Multiply) || !(op.aggOp.increOp.fn instanceof Plus)) {
            throw new DMLRuntimeException("Unsupported binary aggregate operation: (" + op.binaryFn + ", " + op.aggOp + ").");
        }
        MatrixObject m1 = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
        MatrixObject m2 = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
        int rlen = (int)(this._isLeftTransposed ? m1.getNumColumns() : m1.getNumRows());
        int clen = (int)(this._isRightTransposed ? m2.getNumRows() : m2.getNumColumns());
        ec.setMetaData(this._output.getName(), rlen, clen);
        LibMatrixCuMatMult.matmult(ec, ec.getGPUContext(0), this.getExtendedOpcode(), m1, m2, this._output.getName(), this._isLeftTransposed, this._isRightTransposed);
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    private static MatrixBlock transpose(MatrixBlock m1) throws DMLRuntimeException {
        ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), 1);
        return (MatrixBlock)m1.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
    }

    private static boolean isSparse(ExecutionContext ec, String var) throws DMLRuntimeException {
        MatrixObject mo = ec.getMatrixObject(var);
        return LibMatrixCUDA.isInSparseFormat(ec.getGPUContext(0), mo);
    }
}

