/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.AccumulatorV2;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.io.FrameReader;
import org.apache.sysml.runtime.io.FrameReaderFactory;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.transform.encode.Encoder;
import org.apache.sysml.runtime.transform.encode.EncoderComposite;
import org.apache.sysml.runtime.transform.encode.EncoderFactory;
import org.apache.sysml.runtime.transform.encode.EncoderMVImpute;
import org.apache.sysml.runtime.transform.encode.EncoderRecode;
import org.apache.sysml.runtime.transform.meta.TfMetaUtils;
import org.apache.sysml.runtime.transform.meta.TfOffsetMap;
import scala.Tuple2;

/**
 * Spark instruction for multi-return parameterized builtins. The only supported
 * opcode is {@code transformencode}, which consumes a frame and a transform spec
 * and produces two outputs: the encoded matrix ({@code _outputs.get(0)}) and the
 * transform meta data frame ({@code _outputs.get(1)}).
 */
public class MultiReturnParameterizedBuiltinSPInstruction
extends ComputationSPInstruction {
    /** Output operands: index 0 is the encoded matrix, index 1 the meta data frame. */
    protected ArrayList<CPOperand> _outputs;

    private MultiReturnParameterizedBuiltinSPInstruction(Operator op, CPOperand input1, CPOperand input2, ArrayList<CPOperand> outputs, String opcode, String istr) {
        super(op, input1, input2, outputs.get(0), opcode, istr);
        this._sptype = SPInstruction.SPINSTRUCTION_TYPE.MultiReturnBuiltin;
        this._outputs = outputs;
    }

    /**
     * Parses a {@code transformencode} instruction string.
     *
     * @param str instruction string (opcode, input frame, spec, output matrix, output meta frame)
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is unsupported or operands are missing
     */
    public static MultiReturnParameterizedBuiltinSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        ArrayList<CPOperand> outputs = new ArrayList<>();
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("transformencode")) {
            // report malformed instructions as DML errors instead of an index exception
            if (parts.length < 5) {
                throw new DMLRuntimeException("Invalid number of operands in MultiReturnBuiltin instruction: " + str);
            }
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            outputs.add(new CPOperand(parts[3], Expression.ValueType.DOUBLE, Expression.DataType.MATRIX));
            outputs.add(new CPOperand(parts[4], Expression.ValueType.STRING, Expression.DataType.FRAME));
            return new MultiReturnParameterizedBuiltinSPInstruction(null, in1, in2, outputs, opcode, str);
        }
        throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode);
    }

    /**
     * Executes transformencode on Spark.
     * <ol>
     * <li>Builds recode maps per partition, deduplicates and groups tokens per column and
     *     assigns codes 1..n per column. If the spec contains an MV impute encoder, its
     *     replacement values are computed analogously. Both are written as text cells to
     *     the meta data output file.</li>
     * <li>Reads the meta data back into a frame and creates the final encoder from it.</li>
     * <li>Broadcasts the encoder (and, for omit specs, the row-offset map) and applies it to
     *     all input blocks, producing the encoded matrix.</li>
     * </ol>
     *
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the instruction fails
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        try {
            FrameObject fo = sec.getFrameObject(this.input1.getName());
            FrameObject fometa = sec.getFrameObject(this._outputs.get(1).getName());
            JavaPairRDD<Long, FrameBlock> in = (JavaPairRDD<Long, FrameBlock>)sec.getRDDHandleForFrameObject(fo, InputInfo.BinaryBlockInputInfo);
            String spec = ec.getScalarInput(this.input2.getName(), this.input2.getValueType(), this.input2.isLiteral()).getStringValue();
            MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(this.input1.getName());
            MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(this.output.getName());
            // name-based specs need column names; they are taken from the block with key 1
            String[] colnames = !TfMetaUtils.isIDSpec(spec) ? in.lookup(1L).get(0).getColumnNames() : null;
            Encoder encoderBuild = EncoderFactory.createEncoder(spec, colnames, fo.getSchema(), (int)fo.getNumColumns(), null);

            // phase 1: build meta data (recode maps, optional MV impute values) and write it as text cells
            MaxLongAccumulator accMax = registerMaxLongAccumulator(sec.getSparkContext());
            JavaRDD<String> rcMaps = in
                .mapPartitionsToPair(new TransformEncodeBuildFunction(encoderBuild))
                .distinct()
                .groupByKey()
                .flatMap(new TransformEncodeGroupFunction(accMax));
            EncoderMVImpute mva = getMVImputeEncoder(encoderBuild);
            if (mva != null) {
                rcMaps = rcMaps.union(in
                    .mapPartitionsToPair(new TransformEncodeBuild2Function(mva))
                    .groupByKey()
                    .flatMap(new TransformEncodeGroup2Function(mva)));
            }
            rcMaps.saveAsTextFile(fometa.getFileName());

            // phase 2: read the meta data back; the accumulator holds the max number of distinct
            // tokens over all recoded columns, but stays at Long.MIN_VALUE if no recode entries
            // were produced (no recode columns or empty input), so clamp to a valid row count
            long numMetaRows = Math.max(accMax.value(), 0L);
            FrameReader reader = FrameReaderFactory.createFrameReader(InputInfo.TextCellInputInfo);
            FrameBlock meta = reader.readFrameFromHDFS(fometa.getFileName(), numMetaRows, fo.getNumColumns());
            meta.recomputeColumnCardinality();
            meta.setColumnNames(colnames != null ? colnames : meta.getColumnNames());

            // phase 3: apply the final encoder to all blocks
            TfOffsetMap omap = null;
            if (TfMetaUtils.containsOmitSpec(spec, colnames)) {
                omap = new TfOffsetMap(SparkUtils.toIndexedLong(in.mapToPair((PairFunction)new ParameterizedBuiltinSPInstruction.RDDTransformApplyOffsetFunction(spec, colnames)).collect()));
            }
            Encoder encoder = EncoderFactory.createEncoder(spec, colnames, fo.getSchema(), (int)fo.getNumColumns(), meta);
            mcOut.setDimension(mcIn.getRows() - (omap != null ? omap.getNumRmRows() : 0L), encoder.getNumCols());
            Broadcast<Encoder> bmeta = sec.getSparkContext().broadcast(encoder);
            Broadcast<TfOffsetMap> bomap = omap != null ? sec.getSparkContext().broadcast(omap) : null;
            JavaPairRDD tmp = in.mapToPair((PairFunction)new ParameterizedBuiltinSPInstruction.RDDTransformApplyFunction(bmeta, bomap));
            JavaPairRDD<MatrixIndexes, MatrixBlock> out = FrameRDDConverterUtils.binaryBlockToMatrixBlock((JavaPairRDD<Long, FrameBlock>)tmp, mcOut, mcOut);

            // register outputs
            sec.setRDDHandleForVariable(this._outputs.get(0).getName(), out);
            sec.addLineageRDD(this._outputs.get(0).getName(), this.input1.getName());
            sec.setFrameOutput(this._outputs.get(1).getName(), meta);
        }
        catch (IOException ex) {
            // I/O failures while reading back the meta data are rethrown unchecked (cause preserved)
            throw new RuntimeException(ex);
        }
    }

    /**
     * Returns the first MV impute encoder contained in a composite encoder.
     *
     * @param encoder encoder to inspect
     * @return the MV impute encoder, or {@code null} if none (or if not a composite)
     */
    private static EncoderMVImpute getMVImputeEncoder(Encoder encoder) {
        if (encoder instanceof EncoderComposite) {
            for (Encoder cencoder : ((EncoderComposite)encoder).getEncoders()) {
                if (cencoder instanceof EncoderMVImpute) {
                    return (EncoderMVImpute)cencoder;
                }
            }
        }
        return null;
    }

    /**
     * Creates a max-accumulator initialized to {@code Long.MIN_VALUE} and registers it
     * with the given Spark context under the name "max".
     */
    private static MaxLongAccumulator registerMaxLongAccumulator(JavaSparkContext sc) {
        MaxLongAccumulator acc = new MaxLongAccumulator(Long.MIN_VALUE);
        sc.sc().register(acc, "max");
        return acc;
    }

    /**
     * Aggregates per-partition MV impute meta data of one column into a single
     * text-cell line {@code "-2 <colID> <value>"}.
     */
    public static class TransformEncodeGroup2Function
    implements FlatMapFunction<Tuple2<Integer, Iterable<FrameBlock.ColumnMetadata>>, String> {
        private static final long serialVersionUID = 702100641492347459L;
        private EncoderMVImpute _encoder = null;

        public TransformEncodeGroup2Function(EncoderMVImpute encoder) {
            this._encoder = encoder;
        }

        @Override
        public Iterator<String> call(Tuple2<Integer, Iterable<FrameBlock.ColumnMetadata>> arg0) throws Exception {
            int colix = arg0._1();
            Iterator<FrameBlock.ColumnMetadata> iter = arg0._2().iterator();
            ArrayList<String> ret = new ArrayList<>();
            EncoderMVImpute.MVMethod method = this._encoder.getMethod(colix);
            if (method == EncoderMVImpute.MVMethod.GLOBAL_MODE) {
                // merge partial histograms (value -> count carried in numDistinct), then pick the max
                HashMap<String, Long> hist = new HashMap<>();
                while (iter.hasNext()) {
                    FrameBlock.ColumnMetadata cmeta = iter.next();
                    Long cnt = hist.get(cmeta.getMvValue());
                    hist.put(cmeta.getMvValue(), cmeta.getNumDistinct() + (cnt != null ? cnt : 0L));
                }
                long max = Long.MIN_VALUE;
                String mode = null;
                for (Map.Entry<String, Long> e : hist.entrySet()) {
                    if (e.getValue() > max) {
                        mode = e.getKey();
                        max = e.getValue();
                    }
                }
                ret.add("-2 " + colix + " " + mode);
            } else if (method == EncoderMVImpute.MVMethod.GLOBAL_MEAN) {
                // NOTE(review): the partial values come from getMeans()[j]._sum; if those are
                // per-partition means rather than sums, dividing their sum by the total count is
                // not the global mean (a count-weighted combination would be) -- confirm.
                KahanObject kbuff = new KahanObject(0.0, 0.0);
                KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                long count = 0; // long: per-partition counts are long and may exceed Integer.MAX_VALUE in total
                while (iter.hasNext()) {
                    FrameBlock.ColumnMetadata cmeta = iter.next();
                    kplus.execute2(kbuff, Double.parseDouble(cmeta.getMvValue()));
                    count += cmeta.getNumDistinct();
                }
                if (count > 0) {
                    ret.add("-2 " + colix + " " + String.valueOf(kbuff._sum / (double)count));
                }
            } else if (method == EncoderMVImpute.MVMethod.CONSTANT && iter.hasNext()) {
                // all partitions report the same constant replacement; emit it once
                ret.add("-2 " + colix + " " + iter.next().getMvValue());
            }
            return ret.iterator();
        }
    }

    /**
     * Builds MV impute meta data for one partition and emits it as
     * (column ID, (count, value)) pairs, reusing {@link FrameBlock.ColumnMetadata}
     * as a generic (long, String) carrier.
     */
    public static class TransformEncodeBuild2Function
    implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Integer, FrameBlock.ColumnMetadata> {
        private static final long serialVersionUID = 6336375833412029279L;
        private EncoderMVImpute _encoder = null;

        public TransformEncodeBuild2Function(EncoderMVImpute encoder) {
            this._encoder = encoder;
        }

        @Override
        public Iterator<Tuple2<Integer, FrameBlock.ColumnMetadata>> call(Iterator<Tuple2<Long, FrameBlock>> iter) throws Exception {
            while (iter.hasNext()) {
                this._encoder.build(iter.next()._2());
            }
            ArrayList<Tuple2<Integer, FrameBlock.ColumnMetadata>> ret = new ArrayList<>();
            int[] collist = this._encoder.getColList();
            for (int j = 0; j < collist.length; ++j) {
                int colID = collist[j];
                EncoderMVImpute.MVMethod method = this._encoder.getMethod(colID);
                if (method == EncoderMVImpute.MVMethod.GLOBAL_MODE) {
                    HashMap<String, Long> hist = this._encoder.getHistogram(colID);
                    for (Map.Entry<String, Long> e : hist.entrySet()) {
                        ret.add(new Tuple2<Integer, FrameBlock.ColumnMetadata>(colID, new FrameBlock.ColumnMetadata(e.getValue(), e.getKey())));
                    }
                } else if (method == EncoderMVImpute.MVMethod.GLOBAL_MEAN) {
                    // getMeans() is indexed by position j in the column list, not by column ID
                    ret.add(new Tuple2<Integer, FrameBlock.ColumnMetadata>(colID, new FrameBlock.ColumnMetadata(this._encoder.getNonMVCount(colID), String.valueOf(this._encoder.getMeans()[j]._sum))));
                } else if (method == EncoderMVImpute.MVMethod.CONSTANT) {
                    ret.add(new Tuple2<Integer, FrameBlock.ColumnMetadata>(colID, new FrameBlock.ColumnMetadata(0L, this._encoder.getReplacement(colID))));
                }
            }
            return ret.iterator();
        }
    }

    /**
     * Assigns consecutive codes 1..n to the distinct tokens of one column and emits
     * text-cell lines {@code "<code> <colID> <recode map entry>"}; reports n to the
     * max-accumulator (max is idempotent, so repeated updates do not skew the result).
     */
    public static class TransformEncodeGroupFunction
    implements FlatMapFunction<Tuple2<Integer, Iterable<Object>>, String> {
        private static final long serialVersionUID = -1034187226023517119L;
        private MaxLongAccumulator _accMax = null;

        public TransformEncodeGroupFunction(MaxLongAccumulator accMax) {
            this._accMax = accMax;
        }

        @Override
        public Iterator<String> call(Tuple2<Integer, Iterable<Object>> arg0) throws Exception {
            String colID = String.valueOf(arg0._1());
            Iterator<Object> iter = arg0._2().iterator();
            ArrayList<String> ret = new ArrayList<>();
            StringBuilder sb = new StringBuilder();
            long rowID = 1L;
            while (iter.hasNext()) {
                sb.append(rowID);
                sb.append(' ');
                sb.append(colID);
                sb.append(' ');
                sb.append(EncoderRecode.constructRecodeMapEntry(iter.next().toString(), rowID));
                ret.add(sb.toString());
                sb.setLength(0);
                ++rowID;
            }
            // rowID - 1 == number of distinct tokens of this column
            this._accMax.add(rowID - 1L);
            return ret.iterator();
        }
    }

    /**
     * Builds partial recode maps for one partition and emits (column ID, token) pairs.
     * Emits nothing if the encoder contains no recode encoder.
     */
    public static class TransformEncodeBuildFunction
    implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Integer, Object> {
        private static final long serialVersionUID = 6336375833412029279L;
        private EncoderRecode _raEncoder = null;

        public TransformEncodeBuildFunction(Encoder encoder) {
            if (encoder instanceof EncoderComposite) {
                // keeps the last recode encoder found, as before
                for (Encoder cEncoder : ((EncoderComposite)encoder).getEncoders()) {
                    if (cEncoder instanceof EncoderRecode) {
                        this._raEncoder = (EncoderRecode)cEncoder;
                    }
                }
            }
        }

        @Override
        public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>> iter) throws Exception {
            ArrayList<Tuple2<Integer, Object>> ret = new ArrayList<>();
            if (this._raEncoder == null) {
                // spec without recode columns: nothing to build (previously an NPE in the task)
                return ret.iterator();
            }
            while (iter.hasNext()) {
                this._raEncoder.buildPartial(iter.next()._2());
            }
            HashMap<Integer, HashSet<Object>> partial = this._raEncoder.getCPRecodeMapsPartial();
            for (Map.Entry<Integer, HashSet<Object>> e1 : partial.entrySet()) {
                for (Object token : e1.getValue()) {
                    ret.add(new Tuple2<Integer, Object>(e1.getKey(), token));
                }
            }
            // presumably resets partial state so a reused function instance does not re-emit tokens
            this._raEncoder.getCPRecodeMapsPartial().clear();
            return ret.iterator();
        }
    }

    /** Spark accumulator tracking the maximum of all added long values. */
    private static class MaxLongAccumulator
    extends AccumulatorV2<Long, Long> {
        private static final long serialVersionUID = -3739727823287550826L;
        private long _value = Long.MIN_VALUE;

        public MaxLongAccumulator(long value) {
            this._value = value;
        }

        @Override
        public void add(Long arg0) {
            this._value = Math.max(this._value, arg0);
        }

        @Override
        public AccumulatorV2<Long, Long> copy() {
            return new MaxLongAccumulator(this._value);
        }

        @Override
        public boolean isZero() {
            return this._value == Long.MIN_VALUE;
        }

        @Override
        public void merge(AccumulatorV2<Long, Long> arg0) {
            this._value = Math.max(this._value, arg0.value());
        }

        @Override
        public void reset() {
            this._value = Long.MIN_VALUE;
        }

        @Override
        public Long value() {
            return this._value;
        }
    }
}

