/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import scala.Tuple2;

public class ZipmmSPInstruction
extends BinarySPInstruction {
    private boolean _tRewrite = true;

    private ZipmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, boolean tRewrite, String opcode, String istr) {
        super(op, in1, in2, out, opcode, istr);
        this._sptype = SPInstruction.SPINSTRUCTION_TYPE.ZIPMM;
        this._tRewrite = tRewrite;
    }

    public static ZipmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("zipmm")) {
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[3]);
            boolean tRewrite = Boolean.parseBoolean(parts[4]);
            AggregateOperator agg = new AggregateOperator(0.0, Plus.getPlusFnObject());
            AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
            return new ZipmmSPInstruction((Operator)aggbin, in1, in2, out, tRewrite, opcode, str);
        }
        throw new DMLRuntimeException("ZipmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryBlockRDDHandleForVariable(this.input2.getName());
        JavaRDD out = in1.join(in2).values().map((Function)new ZipMultiplyFunction(this._tRewrite));
        MatrixBlock out2 = RDDAggregateUtils.sumStable((JavaRDD<MatrixBlock>)out);
        if (this._tRewrite) {
            ReorgOperator rop = new ReorgOperator(SwapIndex.getSwapIndexFnObject());
            out2 = (MatrixBlock)out2.reorgOperations(rop, new MatrixBlock(), 0, 0, 0);
        }
        sec.setMatrixOutput(this.output.getName(), out2, this.getExtendedOpcode());
    }

    private static class ZipMultiplyFunction
    implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -6669267794926220287L;
        private AggregateBinaryOperator _abop = null;
        private ReorgOperator _rop = null;
        private boolean _tRewrite = true;

        public ZipMultiplyFunction(boolean tRewrite) {
            this._tRewrite = tRewrite;
            AggregateOperator agg = new AggregateOperator(0.0, Plus.getPlusFnObject());
            this._abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
            this._rop = new ReorgOperator(SwapIndex.getSwapIndexFnObject());
        }

        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> arg0) throws Exception {
            MatrixBlock in1 = this._tRewrite ? (MatrixBlock)arg0._1() : (MatrixBlock)arg0._2();
            MatrixBlock in2 = this._tRewrite ? (MatrixBlock)arg0._2() : (MatrixBlock)arg0._1();
            MatrixBlock tmp = (MatrixBlock)in2.reorgOperations(this._rop, new MatrixBlock(), 0, 0, 0);
            return (MatrixBlock)tmp.aggregateBinaryOperations(tmp, in1, new MatrixBlock(), this._abop);
        }
    }
}

