/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.IndexFunction;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.LibMatrixOuterAgg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.util.DataConverter;
import scala.Tuple2;

public class UaggOuterChainSPInstruction extends BinarySPInstruction {
    /** Unary aggregate applied to the result of the outer binary operation. */
    private final AggregateUnaryOperator _uaggOp;
    /** Aggregate used to combine partial results across blocks. */
    private final AggregateOperator _aggOp;
    /** Element-wise binary operator of the outer operation. */
    private final BinaryOperator _bOp;

    private UaggOuterChainSPInstruction(BinaryOperator bop, AggregateUnaryOperator uaggop, AggregateOperator aggop, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(bop, in1, in2, out, opcode, istr);
        this._sptype = SPInstruction.SPINSTRUCTION_TYPE.UaggOuterChain;
        this._uaggOp = uaggop;
        this._aggOp = aggop;
        this._bOp = bop;
        this.instString = istr;
    }

    /**
     * Parses a {@code uaggouterchain} instruction string.
     *
     * <p>Expected parts: [0] opcode, [1] unary aggregate opcode, [2] binary operator,
     * [3] first input, [4] second input, [5] output.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code uaggouterchain}
     */
    public static UaggOuterChainSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("uaggouterchain")) {
            throw new DMLRuntimeException("UaggOuterChainSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        AggregateUnaryOperator uaggop = InstructionUtils.parseBasicAggregateUnaryOperator(parts[1]);
        BinaryOperator bop = InstructionUtils.parseBinaryOperator(parts[2]);
        CPOperand in1 = new CPOperand(parts[3]);
        CPOperand in2 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        // Derive the cross-block aggregate (and its correction handling) from the unary aggregate opcode.
        String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(parts[1]);
        PartialAggregate.CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(parts[1]);
        String corrExists = corrLoc != PartialAggregate.CorrectionLocationType.NONE ? "true" : "false";
        AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, corrExists, corrLoc.toString());
        return new UaggOuterChainSPInstruction(bop, uaggop, aop, in1, in2, out, opcode, str);
    }

    /**
     * Executes the unary aggregate over the outer binary operation.
     *
     * <p>One input is processed as an RDD and the other is broadcast. For supported
     * operator combinations ({@link LibMatrixOuterAgg#isSupportedUaggOp}), the broadcast side
     * is shipped as a plain double array. This array is sorted, or accompanied by
     * precomputed row indices for {@link Builtin} aggregates. Otherwise, a generic
     * block-wise outer operation over a partitioned broadcast is used.
     *
     * <p>Full aggregates ({@link ReduceAll}) are collected into a local matrix output.
     * All other aggregates remain an RDD output.
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        boolean supportedOp = LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp);
        // Row aggregates, full aggregates and the generic fallback stream input1 and broadcast
        // input2; only supported column aggregates stream input2 and broadcast input1.
        boolean rightCached = _uaggOp.indexFn instanceof ReduceCol
            || _uaggOp.indexFn instanceof ReduceAll
            || !supportedOp;
        String rddVar = rightCached ? input1.getName() : input2.getName();
        String bcastVar = rightCached ? input2.getName() : input1.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
        MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddVar);
        boolean noKeyChange = preservesPartitioning(mcIn, _uaggOp.indexFn);

        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (supportedOp) {
            MatrixBlock mb = sec.getMatrixInput(bcastVar, getExtendedOpcode());
            sec.releaseMatrixInput(bcastVar, getExtendedOpcode());
            // The data is shipped through a fresh Spark broadcast of a local array below, so
            // clearing bcastVar skips the addLineageBroadcast call at the end.
            bcastVar = null;
            double[] vmb = DataConverter.convertToDoubleVector(mb);
            Broadcast<int[]> bvi = null;
            if (_uaggOp.aggOp.increOp.fn instanceof Builtin) {
                int[] vix = LibMatrixOuterAgg.prepareRowIndices(mb.getNumColumns(), vmb, _bOp, _uaggOp);
                bvi = sec.getSparkContext().broadcast(vix);
            } else {
                Arrays.sort(vmb);
            }
            Broadcast<double[]> bv = sec.getSparkContext().broadcast(vmb);
            out = in1.mapPartitionsToPair(new RDDMapUAggOuterChainFunction(bv, bvi, _bOp, _uaggOp), noKeyChange);
        } else {
            PartitionedBroadcast<MatrixBlock> bv = sec.getBroadcastForVariable(bcastVar);
            out = in1.mapPartitionsToPair(new RDDMapGenUAggOuterChainFunction(bv, _uaggOp, _aggOp, _bOp, mcIn), noKeyChange);
        }

        if (_uaggOp.indexFn instanceof ReduceAll) {
            MatrixBlock tmp = RDDAggregateUtils.aggStable(out, _aggOp);
            tmp.dropLastRowsOrColumns(_aggOp.correctionLocation);
            sec.setMatrixOutput(output.getName(), tmp, getExtendedOpcode());
        } else {
            updateUnaryAggOutputMatrixCharacteristics(sec);
            if (_uaggOp.aggOp.correctionExists) {
                out = out.mapValues(new AggregateDropCorrectionFunction(_uaggOp.aggOp));
            }
            sec.setRDDHandleForVariable(output.getName(), out);
            sec.addLineageRDD(output.getName(), rddVar);
            if (bcastVar != null) {
                sec.addLineageBroadcast(output.getName(), bcastVar);
            }
        }
    }

    /**
     * Tells whether the aggregation keeps output keys aligned with input partitioning.
     *
     * <p>This holds when the aggregated dimension spans a single block: columns for
     * {@link ReduceCol}, rows otherwise. Unknown dimensions yield {@code false}.
     *
     * @param mcIn characteristics of the streamed (RDD) input
     * @param ixfun index function of the unary aggregate
     * @return true if the output keys keep the input partitioning
     */
    protected static boolean preservesPartitioning(MatrixCharacteristics mcIn, IndexFunction ixfun) {
        if (ixfun instanceof ReduceCol) {
            return mcIn.dimsKnown() && mcIn.getCols() <= (long) mcIn.getColsPerBlock();
        }
        return mcIn.dimsKnown() && mcIn.getRows() <= (long) mcIn.getRowsPerBlock();
    }

    /**
     * Infers the output dimensions from the inputs if they are not already known.
     *
     * <p>{@code mc1} always describes the input that {@link #processInstruction} streams as
     * the RDD for the given aggregate: input1 for {@link ReduceCol}, input2 otherwise.
     *
     * @param sec the Spark execution context holding the matrix characteristics
     * @throws DMLRuntimeException if the output dimensions are unknown and cannot be inferred
     */
    protected void updateUnaryAggOutputMatrixCharacteristics(SparkExecutionContext sec) throws DMLRuntimeException {
        String strInput1Name;
        String strInput2Name;
        if (_uaggOp.indexFn instanceof ReduceCol) {
            strInput1Name = input1.getName();
            strInput2Name = input2.getName();
        } else {
            strInput1Name = input2.getName();
            strInput2Name = input1.getName();
        }
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(strInput1Name);
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(strInput2Name);
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
        if (mcOut.dimsKnown()) {
            return;
        }
        if (!mc1.dimsKnown()) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + mc1.toString() + " " + mcOut.toString());
        }
        if (_uaggOp.indexFn instanceof ReduceAll) {
            mcOut.set(1L, 1L, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        } else if (_uaggOp.indexFn instanceof ReduceCol) {
            mcOut.set(mc1.getRows(), 1L, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        } else if (_uaggOp.indexFn instanceof ReduceRow) {
            // Column aggregates stream input2 (mc1), so the output has one column per column
            // of input2. Previously this read mc2 (input1), which yields the wrong width.
            mcOut.set(1L, mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        }
        // NOTE(review): mc2 is kept only for symmetry of the input lookup above.
        assert mc2 != null;
    }

    /**
     * Generic fallback. For each streamed block, it applies the binary operator against every
     * column block of the broadcast (row block 1), aggregates each intermediate result, and
     * folds the partial aggregates into a single output block.
     */
    private static class RDDMapGenUAggOuterChainFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 8197406787010296291L;
        private final PartitionedBroadcast<MatrixBlock> _pbc;
        private final AggregateUnaryOperator _uaggOp;
        private final AggregateOperator _aggOp;
        private final BinaryOperator _bOp;
        private final int _brlen;
        private final int _bclen;
        // Scratch buffers reused across blocks for the outer result and its aggregate.
        private final MatrixValue _tmpVal1;
        private final MatrixValue _tmpVal2;

        public RDDMapGenUAggOuterChainFunction(PartitionedBroadcast<MatrixBlock> binput, AggregateUnaryOperator uaggOp, AggregateOperator aggOp, BinaryOperator bOp, MatrixCharacteristics mc) {
            this._pbc = binput;
            this._uaggOp = uaggOp;
            this._aggOp = aggOp;
            this._bOp = bOp;
            this._brlen = mc.getRowsPerBlock();
            this._bclen = mc.getColsPerBlock();
            this._tmpVal1 = new MatrixBlock();
            this._tmpVal2 = new MatrixBlock();
        }

        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) throws Exception {
            return new RDDMapGenUAggOuterChainIterator(arg);
        }

        private class RDDMapGenUAggOuterChainIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public RDDMapGenUAggOuterChainIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
                MatrixIndexes in1Ix = arg._1();
                MatrixBlock in1Val = arg._2();
                MatrixIndexes outIx = new MatrixIndexes();
                MatrixBlock outVal = new MatrixBlock();
                MatrixBlock corr = null;
                long in2ColBlocks = _pbc.getNumColumnBlocks();
                for (int bidx = 1; bidx <= in2ColBlocks; bidx++) {
                    MatrixValue in2Val = _pbc.getBlock(1, bidx);
                    OperationsOnMatrixValues.performBinaryIgnoreIndexes(in1Val, in2Val, _tmpVal1, _bOp);
                    OperationsOnMatrixValues.performAggregateUnary(in1Ix, _tmpVal1, outIx, _tmpVal2, _uaggOp, _brlen, _bclen);
                    // Size the accumulator and correction lazily from the first partial aggregate.
                    if (corr == null) {
                        outVal.reset(_tmpVal2.getNumRows(), _tmpVal2.getNumColumns(), false);
                        corr = new MatrixBlock(_tmpVal2.getNumRows(), _tmpVal2.getNumColumns(), false);
                    }
                    MatrixBlock correction = _aggOp.correctionExists ? corr : null;
                    OperationsOnMatrixValues.incrementalAggregation(outVal, correction, _tmpVal2, _aggOp, true);
                }
                return new Tuple2<MatrixIndexes, MatrixBlock>(outIx, outVal);
            }
        }
    }

    /**
     * Specialized path for supported operator combinations. It aggregates each streamed block
     * directly against the broadcast double array, plus the row-index array for
     * rowIndexMax/rowIndexMin.
     */
    private static class RDDMapUAggOuterChainFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 8197406787010296291L;
        private final Broadcast<double[]> _bv;
        // Only created for Builtin aggregate functions; may be null otherwise.
        private final Broadcast<int[]> _bvi;
        private final BinaryOperator _bOp;
        private final AggregateUnaryOperator _uaggOp;

        public RDDMapUAggOuterChainFunction(Broadcast<double[]> bv, Broadcast<int[]> bvi, BinaryOperator bOp, AggregateUnaryOperator uaggOp) {
            this._bv = bv;
            this._bvi = bvi;
            this._bOp = bOp;
            this._uaggOp = uaggOp;
        }

        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) throws Exception {
            return new RDDMapUAggOuterChainIterator(arg0);
        }

        private class RDDMapUAggOuterChainIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public RDDMapUAggOuterChainIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
                MatrixIndexes in1Ix = arg._1();
                MatrixBlock in1Val = arg._2();
                MatrixIndexes outIx = new MatrixIndexes();
                MatrixBlock outVal = new MatrixBlock();
                // The row-index broadcast is only read for rowIndexMax/rowIndexMin.
                int[] bvi = null;
                if (LibMatrixOuterAgg.isRowIndexMax(_uaggOp) || LibMatrixOuterAgg.isRowIndexMin(_uaggOp)) {
                    bvi = _bvi.getValue();
                }
                LibMatrixOuterAgg.resetOutputMatix(in1Ix, in1Val, outIx, outVal, _uaggOp);
                LibMatrixOuterAgg.aggregateMatrix(in1Val, outVal, _bv.value(), bvi, _bOp, _uaggOp);
                return new Tuple2<MatrixIndexes, MatrixBlock>(outIx, outVal);
            }
        }
    }
}

