/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.lops.Ternary;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.CTable;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.CTableMap;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.SimpleOperator;
import org.apache.sysml.runtime.util.LongLongDoubleHashMap;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for contingency-table ({@code ctable}) computation.
 *
 * <p>Handles the opcodes {@code ctable} and {@code ctableexpand}. The first input is always
 * read as a blocked matrix RDD; whether the second and third inputs are matrices or scalars
 * depends on the {@link Ternary.OperationTypes} variant derived from their data types.
 * The result is registered as an RDD of binary cells ({@link MatrixCell}) with block sizes
 * set to -1. Output dimensions are either supplied by the instruction (cells outside them are
 * filtered out) or, when both are -1, computed from the produced cells.
 */
public class TernarySPInstruction
extends ComputationSPInstruction {
    /** First output dimension: a literal value or the name of a scalar variable. */
    private final String _outDim1;
    /** Second output dimension: a literal value or the name of a scalar variable. */
    private final String _outDim2;
    /** True if {@link #_outDim1} is a literal value, false if it names a scalar variable. */
    private final boolean _dim1Literal;
    /** True if {@link #_outDim2} is a literal value, false if it names a scalar variable. */
    private final boolean _dim2Literal;
    /** True for the {@code ctableexpand} opcode. */
    private final boolean _isExpand;
    /** Forwarded to the block-level ctable operation of the scalar-weight variants. */
    private final boolean _ignoreZeros;

    private TernarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, boolean ignoreZeros, String opcode, String istr) {
        super(op, in1, in2, in3, out, opcode, istr);
        this._outDim1 = outputDim1;
        this._dim1Literal = dim1Literal;
        this._outDim2 = outputDim2;
        this._dim2Literal = dim2Literal;
        this._isExpand = isExpand;
        this._ignoreZeros = ignoreZeros;
    }

    /**
     * Parses a {@code ctable} / {@code ctableexpand} instruction string.
     *
     * <p>Operands after the opcode are: in1, in2, in3, dim1, dim2, out, ignoreZeros. Each
     * dimension operand is a value and a literal flag separated by a middle dot (U+00B7).
     *
     * @param inst the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is unsupported, a dimension operand is
     *         malformed, or the field count check in {@code InstructionUtils} fails
     */
    public static TernarySPInstruction parseInstruction(String inst) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
        InstructionUtils.checkNumFields(parts, 7);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ctable") && !opcode.equalsIgnoreCase("ctableexpand")) {
            throw new DMLRuntimeException("Unexpected opcode in TernarySPInstruction: " + inst);
        }
        boolean isExpand = opcode.equalsIgnoreCase("ctableexpand");
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        String[] dim1Fields = splitDimOperand(parts[4], inst);
        String[] dim2Fields = splitDimOperand(parts[5], inst);
        CPOperand out = new CPOperand(parts[6]);
        boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
        return new TernarySPInstruction(new SimpleOperator(null), in1, in2, in3, out,
            dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]),
            dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]),
            isExpand, ignoreZeros, opcode, inst);
    }

    /** Splits a dimension operand into {value, literalFlag, ...}, rejecting malformed operands. */
    private static String[] splitDimOperand(String operand, String inst) throws DMLRuntimeException {
        String[] fields = operand.split("\u00b7");
        if (fields.length < 2) {
            // without the literal flag, fields[1] would throw an unhelpful index exception
            throw new DMLRuntimeException("Malformed output dimension operand '" + operand + "' in TernarySPInstruction: " + inst);
        }
        return fields;
    }

    /**
     * Executes the ctable operation on Spark and registers the binary-cell output RDD,
     * its matrix characteristics, and the RDD lineage of the matrix inputs that were read.
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in3 = null;
        double scalar_input2 = -1.0;
        double scalar_input3 = -1.0;

        Ternary.OperationTypes ctableOp = Ternary.findCtableOperationByInputDataTypes(
            this.input1.getDataType(), this.input2.getDataType(), this.input3.getDataType());
        if (this._isExpand) {
            ctableOp = Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT;
        }

        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(this.input1.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(this.output.getName());
        int brlen = mc1.getRowsPerBlock();
        int bclen = mc1.getColsPerBlock();

        // Exactly one path is expected: per-block CTableMaps (ctables) that are merged below,
        // or cells emitted directly (bincellsNoFilter) by the expand fast path.
        JavaPairRDD<MatrixIndexes, CTableMap> ctables = null;
        JavaPairRDD<MatrixIndexes, Double> bincellsNoFilter = null;
        boolean setLineage2 = false;
        boolean setLineage3 = false;

        switch (ctableOp) {
            case CTABLE_TRANSFORM: {
                // input2 and input3 are matrices: join all three inputs by block index
                in2 = sec.getBinaryBlockRDDHandleForVariable(this.input2.getName());
                in3 = sec.getBinaryBlockRDDHandleForVariable(this.input3.getName());
                setLineage2 = true;
                setLineage3 = true;
                JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> inputMBs =
                    in1.cogroup(in2).cogroup(in3).mapToPair(new MapThreeMBIterableIntoAL());
                ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(
                    ctableOp, scalar_input2, scalar_input3, this.instString, this._optr, this._ignoreZeros));
                break;
            }
            case CTABLE_EXPAND_SCALAR_WEIGHT:
            case CTABLE_TRANSFORM_SCALAR_WEIGHT: {
                // input2 is a matrix, input3 a scalar weight
                scalar_input3 = sec.getScalarInput(this.input3.getName(), this.input3.getValueType(), this.input3.isLiteral()).getDoubleValue();
                in2 = sec.getBinaryBlockRDDHandleForVariable(this.input2.getName());
                setLineage2 = true;
                if (ctableOp == Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT && scalar_input3 == 1.0) {
                    // expand with unit weight: emit cells straight from input2, no join needed
                    bincellsNoFilter = in2.flatMapToPair(new ExpandScalarCtableOperation(brlen));
                } else {
                    JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> inputMBs =
                        in1.cogroup(in2).mapToPair(new MapTwoMBIterableIntoAL());
                    ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(
                        ctableOp, scalar_input2, scalar_input3, this.instString, this._optr, this._ignoreZeros));
                }
                break;
            }
            case CTABLE_TRANSFORM_HISTOGRAM: {
                // input2 and input3 are scalars: only input1 is distributed
                scalar_input2 = sec.getScalarInput(this.input2.getName(), this.input2.getValueType(), this.input2.isLiteral()).getDoubleValue();
                scalar_input3 = sec.getScalarInput(this.input3.getName(), this.input3.getValueType(), this.input3.isLiteral()).getDoubleValue();
                JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> inputMBs = in1.mapToPair(new MapMBIntoAL());
                ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(
                    ctableOp, scalar_input2, scalar_input3, this.instString, this._optr, this._ignoreZeros));
                break;
            }
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
                // input2 is a scalar, input3 a matrix of weights joined with input1
                in3 = sec.getBinaryBlockRDDHandleForVariable(this.input3.getName());
                setLineage3 = true;
                scalar_input2 = sec.getScalarInput(this.input2.getName(), this.input2.getValueType(), this.input2.isLiteral()).getDoubleValue();
                JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> inputMBs =
                    in1.cogroup(in3).mapToPair(new MapTwoMBIterableIntoAL());
                ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(
                    ctableOp, scalar_input2, scalar_input3, this.instString, this._optr, this._ignoreZeros));
                break;
            }
            default: {
                throw new DMLRuntimeException("Encountered an invalid ctable operation (" + ctableOp + ") while executing instruction: " + this.toString());
            }
        }

        if (bincellsNoFilter == null && ctables != null) {
            // flatten the per-block ctables into cells and sum duplicates across blocks
            bincellsNoFilter = ctables.values().flatMapToPair(new ExtractBinaryCellsFromCTable());
            bincellsNoFilter = RDDAggregateUtils.sumCellsByKeyStable(bincellsNoFilter);
        } else if (bincellsNoFilter == null || ctables != null) {
            throw new DMLRuntimeException("Incorrect ctable operation");
        }

        long outputDim1 = this._dim1Literal
            ? (long) Double.parseDouble(this._outDim1)
            : sec.getScalarInput(this._outDim1, Expression.ValueType.DOUBLE, false).getLongValue();
        long outputDim2 = this._dim2Literal
            ? (long) Double.parseDouble(this._outDim2)
            : sec.getScalarInput(this._outDim2, Expression.ValueType.DOUBLE, false).getLongValue();

        // both dims == -1 means "derive the output size from the produced cells"
        final boolean findDimensions = outputDim1 == -1L && outputDim2 == -1L;
        MatrixCharacteristics mcBinaryCells = null;
        if (!findDimensions) {
            // not both -1 here, so a single -1 means the dimensions are only half specified
            if (outputDim1 == -1L || outputDim2 == -1L) {
                throw new DMLRuntimeException("Incorrect output dimensions passed to TernarySPInstruction:" + outputDim1 + " " + outputDim2);
            }
            mcBinaryCells = new MatrixCharacteristics(outputDim1, outputDim2, brlen, bclen);
            bincellsNoFilter = bincellsNoFilter.filter(new FilterCells(mcBinaryCells.getRows(), mcBinaryCells.getCols()));
        }

        JavaPairRDD<MatrixIndexes, MatrixCell> binaryCells = bincellsNoFilter.mapToPair(new ConvertToBinaryCell());
        if (findDimensions) {
            // cache before the extra pass that computes the characteristics
            binaryCells = SparkUtils.cacheBinaryCellRDD(binaryCells);
            mcBinaryCells = SparkUtils.computeMatrixCharacteristics(binaryCells);
        }

        sec.setRDDHandleForVariable(this.output.getName(), binaryCells);
        mcOut.set(mcBinaryCells);
        // the output is in binary-cell format, hence no block sizes
        mcOut.setRowsPerBlock(-1);
        mcOut.setColsPerBlock(-1);

        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        if (setLineage2) {
            sec.addLineageRDD(this.output.getName(), this.input2.getName());
        }
        if (setLineage3) {
            sec.addLineageRDD(this.output.getName(), this.input3.getName());
        }
    }

    /**
     * Returns the single block in {@code blks}, starting from {@code current} (which may
     * already hold a block found earlier).
     *
     * @throws Exception if a second block is found, or if no block is available at all
     */
    private static MatrixBlock extractSingleBlock(Iterable<MatrixBlock> blks, MatrixBlock current) throws Exception {
        MatrixBlock retVal = current;
        for (MatrixBlock blk : blks) {
            if (retVal != null) {
                throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index");
            }
            retVal = blk;
        }
        if (retVal == null) {
            throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index");
        }
        return retVal;
    }

    /**
     * Keeps cells inside the declared output dimensions; fails on non-positive indexes.
     */
    private static class FilterCells
    implements Function<Tuple2<MatrixIndexes, Double>, Boolean> {
        private static final long serialVersionUID = 108448577697623247L;
        private final long rlen;
        private final long clen;

        public FilterCells(long rlen, long clen) {
            this.rlen = rlen;
            this.clen = clen;
        }

        @Override
        public Boolean call(Tuple2<MatrixIndexes, Double> kv) throws Exception {
            MatrixIndexes ix = kv._1();
            if (ix.getRowIndex() <= 0L || ix.getColumnIndex() <= 0L) {
                throw new Exception("Incorrect cell values in TernarySPInstruction:" + ix);
            }
            return ix.getRowIndex() <= this.rlen && ix.getColumnIndex() <= this.clen;
        }
    }

    /** Wraps each cell value into a {@link MatrixCell}, keeping its index. */
    private static class ConvertToBinaryCell
    implements PairFunction<Tuple2<MatrixIndexes, Double>, MatrixIndexes, MatrixCell> {
        private static final long serialVersionUID = 7481186480851982800L;

        private ConvertToBinaryCell() {
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixCell> call(Tuple2<MatrixIndexes, Double> kv) throws Exception {
            MatrixCell cell = new MatrixCell(kv._2().doubleValue());
            return new Tuple2<>(kv._1(), cell);
        }
    }

    /** Emits every (row, col, value) entry of a {@link CTableMap} as an indexed cell. */
    private static class ExtractBinaryCellsFromCTable
    implements PairFlatMapFunction<CTableMap, MatrixIndexes, Double> {
        private static final long serialVersionUID = -5933677686766674444L;

        private ExtractBinaryCellsFromCTable() {
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, Double>> call(CTableMap ctableMap) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, Double>> retVal = new ArrayList<>();
            Iterator<LongLongDoubleHashMap.ADoubleEntry> iter = ctableMap.getIterator();
            while (iter.hasNext()) {
                LongLongDoubleHashMap.ADoubleEntry ijv = iter.next();
                MatrixIndexes ix = new MatrixIndexes(ijv.getKey1(), ijv.getKey2());
                retVal.add(new Tuple2<MatrixIndexes, Double>(ix, ijv.value));
            }
            return retVal.iterator();
        }
    }

    /** Wraps a single block into a one-element list (histogram variant). */
    private static class MapMBIntoAL
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, ArrayList<MatrixBlock>> {
        private static final long serialVersionUID = 2068398913653350125L;

        private MapMBIntoAL() {
        }

        @Override
        public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
            ArrayList<MatrixBlock> retVal = new ArrayList<>(1);
            retVal.add(kv._2());
            return new Tuple2<>(kv._1(), retVal);
        }
    }

    /**
     * Computes the {@link CTableMap} of one block index from the blocks collected by the
     * {@code Map*IntoAL} functions. Expected list sizes: 3 for CTABLE_TRANSFORM, 2 for the
     * scalar-weight and weighted-histogram variants, 1 for the histogram variant.
     */
    private static class PerformCTableMapSideOperation
    implements PairFunction<Tuple2<MatrixIndexes, ArrayList<MatrixBlock>>, MatrixIndexes, CTableMap> {
        private static final long serialVersionUID = 5348127596473232337L;
        private final Ternary.OperationTypes ctableOp;
        private final double scalar_input2;
        private final double scalar_input3;
        private final String instString;
        private final Operator optr;
        private final boolean ignoreZeros;

        public PerformCTableMapSideOperation(Ternary.OperationTypes ctableOp, double scalar_input2, double scalar_input3, String instString, Operator optr, boolean ignoreZeros) {
            this.ctableOp = ctableOp;
            this.scalar_input2 = scalar_input2;
            this.scalar_input3 = scalar_input3;
            this.instString = instString;
            this.optr = optr;
            this.ignoreZeros = ignoreZeros;
        }

        private static void expectedALSize(int length, ArrayList<MatrixBlock> al) throws Exception {
            if (al.size() != length) {
                throw new Exception("Expected arraylist of size:" + length + ", but found " + al.size());
            }
        }

        @Override
        public Tuple2<MatrixIndexes, CTableMap> call(Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> kv) throws Exception {
            MatrixIndexes ix = kv._1();
            ArrayList<MatrixBlock> blocks = kv._2();
            CTableMap ctableResult = new CTableMap();
            // results are collected in ctableResult; no output block is passed down
            MatrixBlock ctableResultBlock = null;
            // each case validates the list size BEFORE indexing into it
            switch (this.ctableOp) {
                case CTABLE_TRANSFORM: {
                    expectedALSize(3, blocks);
                    IndexedMatrixValue in1 = new IndexedMatrixValue(ix, blocks.get(0));
                    IndexedMatrixValue in2 = new IndexedMatrixValue(ix, blocks.get(1));
                    IndexedMatrixValue in3 = new IndexedMatrixValue(ix, blocks.get(2));
                    OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, this.optr);
                    break;
                }
                case CTABLE_EXPAND_SCALAR_WEIGHT:
                case CTABLE_TRANSFORM_SCALAR_WEIGHT: {
                    expectedALSize(2, blocks);
                    MatrixBlock matBlock1 = blocks.get(0);
                    MatrixValue in2Value = blocks.get(1);
                    matBlock1.ternaryOperations((SimpleOperator) this.optr, in2Value, this.scalar_input3, this.ignoreZeros, ctableResult, ctableResultBlock);
                    break;
                }
                case CTABLE_TRANSFORM_HISTOGRAM: {
                    expectedALSize(1, blocks);
                    IndexedMatrixValue in1 = new IndexedMatrixValue(ix, blocks.get(0));
                    OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), this.scalar_input2, this.scalar_input3, ctableResult, ctableResultBlock, this.optr);
                    break;
                }
                case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
                    expectedALSize(2, blocks);
                    IndexedMatrixValue in1 = new IndexedMatrixValue(ix, blocks.get(0));
                    IndexedMatrixValue in3 = new IndexedMatrixValue(ix, blocks.get(1));
                    OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), this.scalar_input2, in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, this.optr);
                    break;
                }
                default: {
                    throw new DMLRuntimeException("Unrecognized opcode in Ternary Instruction: " + this.instString);
                }
            }
            return new Tuple2<>(ix, ctableResult);
        }
    }

    /**
     * Flattens the result of {@code in1.cogroup(in2).cogroup(in3)} into a list
     * [in1, in2, in3], requiring exactly one block per input at every index.
     */
    private static class MapThreeMBIterableIntoAL
    implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Iterable<Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>>, Iterable<MatrixBlock>>>, MatrixIndexes, ArrayList<MatrixBlock>> {
        private static final long serialVersionUID = -4873754507037646974L;

        private MapThreeMBIterableIntoAL() {
        }

        @Override
        public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call(Tuple2<MatrixIndexes, Tuple2<Iterable<Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>>, Iterable<MatrixBlock>>> kv) throws Exception {
            MatrixBlock in1 = null;
            MatrixBlock in2 = null;
            for (Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>> blks : kv._2()._1()) {
                in1 = extractSingleBlock(blks._1(), in1);
                in2 = extractSingleBlock(blks._2(), in2);
            }
            // An index present only in the third input yields an empty outer iterable, so the
            // loop above never ran; report it instead of passing null blocks downstream.
            if (in1 == null || in2 == null) {
                throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index");
            }
            MatrixBlock in3 = extractSingleBlock(kv._2()._2(), null);
            ArrayList<MatrixBlock> inputs = new ArrayList<>(3);
            inputs.add(in1);
            inputs.add(in2);
            inputs.add(in3);
            return new Tuple2<>(kv._1(), inputs);
        }
    }

    /**
     * Flattens the result of a two-way cogroup into a list [first, second], requiring
     * exactly one block per input at every index.
     */
    private static class MapTwoMBIterableIntoAL
    implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>>, MatrixIndexes, ArrayList<MatrixBlock>> {
        private static final long serialVersionUID = 271459913267735850L;

        private MapTwoMBIterableIntoAL() {
        }

        @Override
        public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call(Tuple2<MatrixIndexes, Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>> kv) throws Exception {
            MatrixBlock in1 = extractSingleBlock(kv._2()._1(), null);
            MatrixBlock in2 = extractSingleBlock(kv._2()._2(), null);
            ArrayList<MatrixBlock> inputs = new ArrayList<>(2);
            inputs.add(in1);
            inputs.add(in2);
            return new Tuple2<>(kv._1(), inputs);
        }
    }

    /**
     * Expand fast path with unit weight: for each row i of a block, runs
     * {@code CTable.execute(globalRow(i), value(i, 0), 1.0)} and keeps results whose row
     * index is at least 1 (rows below 1 are treated as rejected entries).
     */
    private static class ExpandScalarCtableOperation
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, Double> {
        private static final long serialVersionUID = -12552669148928288L;
        private final int _brlen;

        public ExpandScalarCtableOperation(int brlen) {
            this._brlen = brlen;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, Double>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            MatrixBlock mb = arg0._2();
            ArrayList<Tuple2<MatrixIndexes, Double>> retVal = new ArrayList<>();
            CTable ctab = CTable.getCTableFnObject();
            for (int i = 0; i < mb.getNumRows(); ++i) {
                long row = UtilFunctions.computeCellIndex(ix.getRowIndex(), this._brlen, i);
                double v2 = mb.quickGetValue(i, 0);
                Pair<MatrixIndexes, Double> p = ctab.execute(row, v2, 1.0);
                if (p.getKey().getRowIndex() >= 1L) {
                    retVal.add(new Tuple2<MatrixIndexes, Double>(p.getKey(), p.getValue()));
                }
            }
            return retVal.iterator();
        }
    }
}

