/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.sysml.lops.MapMultChain;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import scala.Tuple2;

/**
 * Spark instruction for the fused matrix-multiplication chain ({@code mapmmchain}).
 *
 * <p>The first input is read as a blocked RDD and the second as a partitioned broadcast.
 * For chain types other than {@link MapMultChain.ChainType#XtXv}, a third input is also
 * read as a partitioned broadcast. Each block of the first input is passed through
 * {@link MatrixBlock#chainMatrixMultOperations}, and the per-block partial results are
 * aggregated with {@link RDDAggregateUtils#sumStable} into a single output block.
 */
public class MapmmChainSPInstruction extends SPInstruction {
    private final MapMultChain.ChainType _chainType;
    private final CPOperand _input1;
    private final CPOperand _input2;
    /** Third input; {@code null} when the instruction was parsed with only two inputs. */
    private final CPOperand _input3;
    private final CPOperand _output;

    /** Creates an instruction without a third input. */
    private MapmmChainSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
            MapMultChain.ChainType type, String opcode, String istr) {
        this(op, in1, in2, null, out, type, opcode, istr);
    }

    /** Creates an instruction; {@code in3} may be {@code null} if no third input exists. */
    private MapmmChainSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3,
            CPOperand out, MapMultChain.ChainType type, String opcode, String istr) {
        super(SPInstruction.SPType.MAPMMCHAIN, op, opcode, istr);
        this._input1 = in1;
        this._input2 = in2;
        this._input3 = in3;
        this._output = out;
        this._chainType = type;
    }

    /**
     * Parses a {@code mapmmchain} instruction string.
     *
     * <p>Field layout after the opcode is either {@code in1, in2, out, chainType}
     * (parts.length == 5) or {@code in1, in2, in3, out, chainType} (parts.length == 6).
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the field count is invalid or the opcode is not
     *         {@code mapmmchain}
     */
    public static MapmmChainSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 4, 5);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("mapmmchain")) {
            throw new DMLRuntimeException("MapmmChainSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        if (parts.length == 5) {
            CPOperand out = new CPOperand(parts[3]);
            MapMultChain.ChainType type = MapMultChain.ChainType.valueOf(parts[4]);
            return new MapmmChainSPInstruction(null, in1, in2, out, type, opcode, str);
        }
        // parts.length == 6: an additional third input precedes the output.
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);
        MapMultChain.ChainType type = MapMultChain.ChainType.valueOf(parts[5]);
        return new MapmmChainSPInstruction(null, in1, in2, in3, out, type, opcode, str);
    }

    /**
     * Executes the chain multiplication on Spark and stores the single aggregated result
     * block as the output matrix variable.
     *
     * @param ec the execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if input retrieval fails, or if a chain type other than
     *         XtXv is requested without a third input
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inX = sec.getBinaryBlockRDDHandleForVariable(_input1.getName());
        PartitionedBroadcast<MatrixBlock> inV = sec.getBroadcastForVariable(_input2.getName());

        JavaRDD<MatrixBlock> partials;
        if (_chainType == MapMultChain.ChainType.XtXv) {
            // XtXv only needs the block values; the block indexes are not used.
            partials = inX.values().map(new RDDMapMMChainFunction(inV));
        } else {
            if (_input3 == null) {
                // Fail with context instead of an NPE on _input3.getName().
                throw new DMLRuntimeException("MapmmChainSPInstruction.processInstruction():: Chain type "
                        + _chainType + " requires a third input, but none was specified.");
            }
            PartitionedBroadcast<MatrixBlock> inW = sec.getBroadcastForVariable(_input3.getName());
            // Keep the block index so each block can be matched with its row block of the third input.
            partials = inX.map(new RDDMapMMChainFunction2(inV, inW, _chainType));
        }
        MatrixBlock out = RDDAggregateUtils.sumStable(partials);
        sec.setMatrixOutput(_output.getName(), out, getExtendedOpcode());
    }

    /**
     * Computes the partial chain product of one (index, block) pair of the first input.
     * The block's row index selects the matching block (rowIx, 1) of the third input.
     */
    private static class RDDMapMMChainFunction2
            implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -7926980450209760212L;
        private final PartitionedBroadcast<MatrixBlock> _pmV;
        private final PartitionedBroadcast<MatrixBlock> _pmW;
        private final MapMultChain.ChainType _chainType;

        public RDDMapMMChainFunction2(PartitionedBroadcast<MatrixBlock> bV,
                PartitionedBroadcast<MatrixBlock> bW, MapMultChain.ChainType chain) {
            this._pmV = bV;
            this._pmW = bW;
            this._chainType = chain;
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            // NOTE(review): only block (1,1) of the second input is read, so it is assumed
            // to fit in a single block — confirm against the planner that emits mapmmchain.
            MatrixBlock pmV = _pmV.getBlock(1, 1);
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            int rowIx = (int) ixIn.getRowIndex();
            return blkIn.chainMatrixMultOperations(pmV, _pmW.getBlock(rowIx, 1), new MatrixBlock(), _chainType);
        }
    }

    /** Computes the XtXv partial chain product of one block of the first input. */
    private static class RDDMapMMChainFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 8197406787010296291L;
        private final PartitionedBroadcast<MatrixBlock> _pmV;

        public RDDMapMMChainFunction(PartitionedBroadcast<MatrixBlock> bV) {
            this._pmV = bV;
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            // NOTE(review): assumes the second input is a single block at (1,1) — confirm.
            MatrixBlock pmV = _pmV.getBlock(1, 1);
            return arg0.chainMatrixMultOperations(pmV, null, new MatrixBlock(), MapMultChain.ChainType.XtXv);
        }
    }
}

