/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.mr;

import java.util.ArrayList;
import org.apache.sysml.lops.MMCJ;
import org.apache.sysml.lops.MapMult;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase;
import org.apache.sysml.runtime.instructions.mr.IDistributedCacheConsumer;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;

/**
 * MR instruction for aggregate-binary (matrix multiplication) operations with the
 * opcodes {@code cpmm}, {@code rmm} and {@code mapmm}.
 *
 * <p>All opcodes use an {@link AggregateBinaryOperator} built from {@link Multiply}
 * with a {@link Plus} aggregation whose initial value is {@code 0.0}, i.e. a
 * sum-of-products. For {@code mapmm} one input is read from the distributed cache
 * (the right input if {@link MapMult.CacheType#isRight()}, otherwise the left one)
 * while the other input is streamed through {@link #processInstruction}.
 */
public class AggregateBinaryInstruction
extends BinaryMRInstructionBase
implements IDistributedCacheConsumer {
    /** Opcode taken verbatim from the instruction string (cpmm, rmm or mapmm, any case). */
    private final String _opcode;
    /** MMCJ aggregation type; defaults to AGG, overridden by a 5-part instruction string. */
    private MMCJ.MMCJType _aggType = MMCJ.MMCJType.AGG;
    /** Which mapmm input is cached; remains null unless set (6-part instruction string or setter). */
    private MapMult.CacheType _cacheType = null;
    /** Whether mapmm keeps empty output blocks; defaults to true. */
    private boolean _outputEmptyBlocks = true;

    private AggregateBinaryInstruction(Operator op, String opcode, byte in1, byte in2, byte out, String istr) {
        super(op, in1, in2, out);
        this.mrtype = MRInstruction.MRINSTRUCTION_TYPE.AggregateBinary;
        this.instString = istr;
        this._opcode = opcode;
    }

    /**
     * Sets which input of a mapmm instruction is served from the distributed cache.
     *
     * @param type the cache type (its {@code isRight()} selects input2 vs. input1)
     */
    public void setCacheTypeMapMult(MapMult.CacheType type) {
        this._cacheType = type;
    }

    /**
     * Sets whether mapmm should keep output blocks that turn out to be empty.
     *
     * @param flag {@code true} to always keep output blocks
     */
    public void setOutputEmptyBlocksMapMult(boolean flag) {
        this._outputEmptyBlocks = flag;
    }

    /** @return whether empty output blocks are kept (mapmm) */
    public boolean getOutputEmptyBlocks() {
        return this._outputEmptyBlocks;
    }

    /** @param type the MMCJ aggregation type to use */
    public void setMMCJType(MMCJ.MMCJType type) {
        this._aggType = type;
    }

    /** @return the MMCJ aggregation type (AGG unless configured otherwise) */
    public MMCJ.MMCJType getMMCJType() {
        return this._aggType;
    }

    /**
     * Parses an aggregate-binary instruction string.
     *
     * <p>Expected operand layout after splitting with
     * {@link InstructionUtils#getInstructionParts(String)}:
     * {@code opcode, in1, in2, out[, extra...]}. With 5 parts the 5th is an
     * {@link MMCJ.MMCJType} name; with 6 parts the 5th is a {@link MapMult.CacheType}
     * name and the 6th the "output empty blocks" boolean.
     *
     * @param str the full instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not cpmm, rmm or mapmm (case-insensitive)
     */
    public static AggregateBinaryInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionParts(str);
        String opcode = parts[0];
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte out = Byte.parseByte(parts[3]);
        if (opcode.equalsIgnoreCase("cpmm") || opcode.equalsIgnoreCase("rmm") || opcode.equalsIgnoreCase("mapmm")) {
            AggregateOperator agg = new AggregateOperator(0.0, Plus.getPlusFnObject());
            AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
            AggregateBinaryInstruction inst = new AggregateBinaryInstruction(aggbin, opcode, in1, in2, out, str);
            if (parts.length == 5) {
                inst.setMMCJType(MMCJ.MMCJType.valueOf(parts[4]));
            } else if (parts.length == 6) {
                inst.setCacheTypeMapMult(MapMult.CacheType.valueOf(parts[4]));
                inst.setOutputEmptyBlocksMapMult(Boolean.parseBoolean(parts[5]));
            }
            return inst;
        }
        throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    /**
     * Reports whether {@code index} is consumed only through the distributed cache,
     * i.e. it is the cached input and not also used as the streamed input.
     * Instructions without a cache type (cpmm/rmm) use no distributed cache and
     * therefore always return {@code false}.
     */
    @Override
    public boolean isDistCacheOnlyIndex(String inst, byte index) {
        if (this._cacheType == null) {
            return false;
        }
        return this._cacheType.isRight()
            ? index == this.input2 && index != this.input1
            : index == this.input1 && index != this.input2;
    }

    /**
     * Appends the index of the distributed-cache input to {@code indexes}.
     * Instructions without a cache type (cpmm/rmm) contribute nothing.
     */
    @Override
    public void addDistCacheIndex(String inst, ArrayList<Byte> indexes) {
        if (this._cacheType == null) {
            return;
        }
        indexes.add(this._cacheType.isRight() ? this.input2 : this.input1);
    }

    /**
     * Executes the multiplication on the blocks currently held in {@code cachedValues}.
     *
     * <p>For mapmm, only the streamed (non-cached) input must be present; the cached
     * side is read block-wise from the distributed cache. For cpmm/rmm both inputs
     * must be present, otherwise nothing is produced.
     *
     * @throws DMLRuntimeException if a mapmm instruction has no cache type, or if the
     *         underlying block operation fails
     */
    @Override
    public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) throws DMLRuntimeException {
        IndexedMatrixValue in1 = cachedValues.getFirst(this.input1);
        IndexedMatrixValue in2 = cachedValues.getFirst(this.input2);
        // parseInstruction accepts opcodes case-insensitively, so dispatch must match.
        if (this._opcode.equalsIgnoreCase("mapmm")) {
            if (this._cacheType == null) {
                throw new DMLRuntimeException("AggregateBinaryInstruction.processInstruction():: missing cache type for mapmm instruction " + this.instString);
            }
            IndexedMatrixValue streamed = this._cacheType.isRight() ? in1 : in2;
            if (streamed == null) {
                return;
            }
            this.processMapMultInstruction(valueClass, cachedValues, in1, in2, blockRowFactor, blockColFactor);
        } else {
            if (in1 == null || in2 == null) {
                return;
            }
            // Writing into a slot of an input would clobber it mid-computation, so use tempValue
            // when the output index coincides with an input index and add it afterwards.
            IndexedMatrixValue out = this.output == this.input1 || this.output == this.input2 ? tempValue : cachedValues.holdPlace(this.output, valueClass);
            OperationsOnMatrixValues.performAggregateBinary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), out.getIndexes(), out.getValue(), (AggregateBinaryOperator)this.optr);
            if (out == tempValue) {
                cachedValues.add(this.output, out);
            }
        }
    }

    /**
     * Map-side multiply: combines the streamed block with every matching block of the
     * distributed-cache input, producing one output block per cached block.
     *
     * <p>If empty-block output is disabled and every produced block is empty, the
     * output index is removed from {@code cachedValues} afterwards.
     */
    private void processMapMultInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue in1, IndexedMatrixValue in2, int blockRowFactor, int blockColFactor) throws DMLRuntimeException {
        boolean removeOutput = true;
        if (this._cacheType.isRight()) {
            // Right input cached: walk all column blocks of input2 in the block-row
            // matching in1's column block index.
            DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(this.input2);
            long in2ColBlocks = (long)Math.ceil((double)dcInput.getNumCols() / (double)dcInput.getNumColsPerBlock());
            for (int bidx = 1; bidx <= in2ColBlocks; bidx++) {
                IndexedMatrixValue in2Block = dcInput.getDataBlock((int)in1.getIndexes().getColumnIndex(), bidx);
                boolean droppable = this.multiplyIntoNewOutput(valueClass, cachedValues, in1.getIndexes(), in1.getValue(), in2Block.getIndexes(), in2Block.getValue());
                removeOutput &= droppable;
            }
        } else {
            // Left input cached: walk all row blocks of input1 in the block-column
            // matching in2's row block index.
            DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(this.input1);
            long in1RowBlocks = (long)Math.ceil((double)dcInput.getNumRows() / (double)dcInput.getNumRowsPerBlock());
            for (int bidx = 1; bidx <= in1RowBlocks; bidx++) {
                IndexedMatrixValue in1Block = dcInput.getDataBlock(bidx, (int)in2.getIndexes().getRowIndex());
                boolean droppable = this.multiplyIntoNewOutput(valueClass, cachedValues, in1Block.getIndexes(), in1Block.getValue(), in2.getIndexes(), in2.getValue());
                removeOutput &= droppable;
            }
        }
        if (removeOutput) {
            cachedValues.remove(this.output);
        }
    }

    /**
     * Multiplies one left/right block pair into a newly reserved output slot.
     *
     * @return {@code true} if the produced block may be dropped, i.e. empty-block
     *         output is disabled and the result is empty
     */
    private boolean multiplyIntoNewOutput(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, MatrixIndexes leftIndexes, MatrixValue leftValue, MatrixIndexes rightIndexes, MatrixValue rightValue) throws DMLRuntimeException {
        IndexedMatrixValue out = cachedValues.holdPlace(this.output, valueClass);
        OperationsOnMatrixValues.performAggregateBinary(leftIndexes, leftValue, rightIndexes, rightValue, out.getIndexes(), out.getValue(), (AggregateBinaryOperator)this.optr);
        return !this._outputEmptyBlocks && out.getValue().isEmpty();
    }
}

