/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionerSparkAggregator;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionerSparkMapper;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;
import scala.Tuple2;

public class ParamservUtils {
    protected static final Log LOG = LogFactory.getLog((String)ParamservUtils.class.getName());
    public static final String PS_FUNC_PREFIX = "_ps_";
    public static long SEED = -1L;

    public static ListObject copyList(ListObject lo, boolean cleanup) {
        List<Data> newData = IntStream.range(0, lo.getLength()).mapToObj(i -> {
            Data oldData = lo.slice(i);
            if (oldData instanceof MatrixObject) {
                return ParamservUtils.createShallowCopy((MatrixObject)oldData);
            }
            if (oldData instanceof ListObject || oldData instanceof FrameObject) {
                throw new DMLRuntimeException("Copy list: does not support list or frame.");
            }
            return oldData;
        }).collect(Collectors.toList());
        ListObject result = new ListObject(newData, lo.getNames());
        if (cleanup) {
            ParamservUtils.cleanupListObject(lo);
        }
        return result;
    }

    public static void cleanupListObject(ExecutionContext ec, String lName) {
        ListObject lo = (ListObject)ec.removeVariable(lName);
        ParamservUtils.cleanupListObject(ec, lo, lo.getStatus());
    }

    public static void cleanupListObject(ExecutionContext ec, String lName, boolean[] status) {
        ListObject lo = (ListObject)ec.removeVariable(lName);
        ParamservUtils.cleanupListObject(ec, lo, status);
    }

    public static void cleanupListObject(ExecutionContext ec, ListObject lo) {
        ParamservUtils.cleanupListObject(ec, lo, lo.getStatus());
    }

    public static void cleanupListObject(ExecutionContext ec, ListObject lo, boolean[] status) {
        for (int i = 0; i < lo.getLength(); ++i) {
            if (status != null && !status[i]) continue;
            ParamservUtils.cleanupData(ec, lo.getData().get(i));
        }
    }

    public static void cleanupData(ExecutionContext ec, Data data) {
        if (!(data instanceof CacheableData)) {
            return;
        }
        CacheableData cd = (CacheableData)data;
        cd.enableCleanup(true);
        ec.cleanupCacheableData(cd);
    }

    public static void cleanupData(ExecutionContext ec, String varName) {
        ParamservUtils.cleanupData(ec, ec.removeVariable(varName));
    }

    public static void cleanupListObject(ListObject lo) {
        ParamservUtils.cleanupListObject(ExecutionContextFactory.createContext(), lo);
    }

    public static MatrixObject newMatrixObject(MatrixBlock mb) {
        return ParamservUtils.newMatrixObject(mb, true);
    }

    public static MatrixObject newMatrixObject(MatrixBlock mb, boolean cleanup) {
        MatrixObject result = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName(), new MetaDataFormat(new MatrixCharacteristics(-1L, -1L, ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()), Types.FileFormat.BINARY));
        result.acquireModify(mb);
        result.release();
        result.enableCleanup(cleanup);
        return result;
    }

    public static MatrixObject createShallowCopy(MatrixObject mo) {
        return ParamservUtils.newMatrixObject((MatrixBlock)mo.acquireReadAndRelease(), false);
    }

    public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) {
        MatrixBlock mb = (MatrixBlock)mo.acquireReadAndRelease();
        return ParamservUtils.newMatrixObject(ParamservUtils.sliceMatrixBlock(mb, rl, rh), false);
    }

    public static MatrixBlock sliceMatrixBlock(MatrixBlock mb, long rl, long rh) {
        return mb.slice((int)rl - 1, (int)rh - 1);
    }

    public static MatrixBlock generatePermutation(int numEntries, long seed) {
        MatrixBlock seq = new MatrixBlock(numEntries, 1, false);
        MatrixBlock sample = MatrixBlock.sampleOperations(numEntries, numEntries, false, seed);
        return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(numEntries, numEntries, true));
    }

    public static MatrixBlock generateSubsampleMatrix(int nsamples, int nrows, long seed) {
        MatrixBlock seq = new MatrixBlock(nsamples, nrows, false);
        MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, false, seed);
        return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(nsamples, nrows, true), false);
    }

    public static MatrixBlock generateReplicationMatrix(int nsamples, int nrows, long seed) {
        MatrixBlock seq = new MatrixBlock(nsamples, nrows, false);
        MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, true, seed);
        return seq.ctableSeqOperations(sample, 1.0, new MatrixBlock(nsamples, nrows, true), false);
    }

    public static ExecutionContext createExecutionContext(ExecutionContext ec, LocalVariableMap varsMap, String updFunc, String aggFunc, int k) {
        return ParamservUtils.createExecutionContext(ec, varsMap, updFunc, aggFunc, k, false);
    }

    public static ExecutionContext createExecutionContext(ExecutionContext ec, LocalVariableMap varsMap, String updFunc, String aggFunc, int k, boolean forceExecTypeCP) {
        Program prog = ec.getProgram();
        ParamservUtils.recompileProgramBlocks(k, prog.getProgramBlocks(), forceExecTypeCP);
        boolean opt = prog.getFunctionProgramBlocks(false).isEmpty();
        prog.getFunctionProgramBlocks(opt).forEach((fname, fvalue) -> ParamservUtils.recompileProgramBlocks(k, fvalue.getChildBlocks(), forceExecTypeCP));
        return ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), ParamservUtils.copyProgramFunctions(prog));
    }

    public static List<ExecutionContext> copyExecutionContext(ExecutionContext ec, int num) {
        return IntStream.range(0, num).mapToObj(i -> ExecutionContextFactory.createContext(new LocalVariableMap(ec.getVariables()), ParamservUtils.copyProgramFunctions(ec.getProgram()))).collect(Collectors.toList());
    }

    private static Program copyProgramFunctions(Program prog) {
        Program newProg = new Program(prog.getDMLProg());
        boolean opt = prog.getFunctionProgramBlocks(false).isEmpty();
        for (Map.Entry<String, FunctionProgramBlock> e : prog.getFunctionProgramBlocks(opt).entrySet()) {
            String[] parts = DMLProgram.splitFunctionKey(e.getKey());
            FunctionProgramBlock fpb = ProgramConverter.createDeepCopyFunctionProgramBlock(e.getValue(), new HashSet<String>(), new HashSet<String>());
            newProg.addFunctionProgramBlock(parts[0], parts[1], fpb, opt);
        }
        return newProg;
    }

    public static void recompileProgramBlocks(int k, List<ProgramBlock> pbs) {
        ParamservUtils.recompileProgramBlocks(k, pbs, false);
    }

    public static void recompileProgramBlocks(int k, List<ProgramBlock> pbs, boolean forceExecTypeCP) {
        for (ProgramBlock pb : pbs) {
            DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());
        }
        try {
            if (forceExecTypeCP) {
                ParamservUtils.rAssignParallelismAndRecompile(pbs, k, true, forceExecTypeCP);
            } else {
                ParamservUtils.rAssignParallelismAndRecompile(pbs, k, false, forceExecTypeCP);
            }
        }
        catch (IOException e) {
            throw new DMLRuntimeException(e);
        }
    }

    private static boolean rAssignParallelismAndRecompile(List<ProgramBlock> pbs, int k, boolean recompiled, boolean forceExecTypeCP) throws IOException {
        for (ProgramBlock pb : pbs) {
            if (pb instanceof ParForProgramBlock) {
                ParForProgramBlock pfpb = (ParForProgramBlock)pb;
                if (!pfpb.isDegreeOfParallelismFixed()) {
                    pfpb.setDegreeOfParallelism(k);
                    if (k == 1) {
                        pfpb.setOptimizationMode(ParForProgramBlock.POptMode.NONE);
                    }
                    recompiled |= ParamservUtils.rAssignParallelismAndRecompile(pfpb.getChildBlocks(), 1, recompiled, forceExecTypeCP);
                }
            } else if (pb instanceof ForProgramBlock) {
                recompiled |= ParamservUtils.rAssignParallelismAndRecompile(((ForProgramBlock)pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
            } else if (pb instanceof WhileProgramBlock) {
                recompiled |= ParamservUtils.rAssignParallelismAndRecompile(((WhileProgramBlock)pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
            } else if (pb instanceof FunctionProgramBlock) {
                recompiled |= ParamservUtils.rAssignParallelismAndRecompile(((FunctionProgramBlock)pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
            } else if (pb instanceof IfProgramBlock) {
                IfProgramBlock ipb = (IfProgramBlock)pb;
                recompiled |= ParamservUtils.rAssignParallelismAndRecompile(ipb.getChildBlocksIfBody(), k, recompiled, forceExecTypeCP);
                if (ipb.getChildBlocksElseBody() != null) {
                    recompiled |= ParamservUtils.rAssignParallelismAndRecompile(ipb.getChildBlocksElseBody(), k, recompiled, forceExecTypeCP);
                }
            } else {
                StatementBlock sb = pb.getStatementBlock();
                for (Hop hop : sb.getHops()) {
                    recompiled |= ParamservUtils.rAssignParallelismAndRecompile(hop, k, recompiled);
                }
            }
            if (!recompiled) continue;
            if (forceExecTypeCP) {
                Recompiler.rRecompileProgramBlock2Forced(pb, pb.getThreadID(), new HashSet<String>(), Types.ExecType.CP);
                continue;
            }
            Recompiler.recompileProgramBlockInstructions(pb);
        }
        return recompiled;
    }

    private static boolean rAssignParallelismAndRecompile(Hop hop, int k, boolean recompiled) {
        if (hop.isVisited()) {
            return recompiled;
        }
        if (hop instanceof MultiThreadedHop) {
            MultiThreadedHop mhop = (MultiThreadedHop)hop;
            mhop.setMaxNumThreads(k);
            recompiled = true;
        }
        ArrayList<Hop> inputs = hop.getInput();
        for (Hop h : inputs) {
            recompiled |= ParamservUtils.rAssignParallelismAndRecompile(h, k, recompiled);
        }
        hop.setVisited();
        return recompiled;
    }

    private static FunctionProgramBlock getFunctionBlock(ExecutionContext ec, String funcName) {
        String[] cfn = DMLProgram.splitFunctionKey(funcName);
        String ns = cfn[0];
        String fname = cfn[1];
        return ec.getProgram().getFunctionProgramBlock(ns, fname);
    }

    public static MatrixBlock cbindMatrix(MatrixBlock left, MatrixBlock right) {
        return left.append(right, new MatrixBlock());
    }

    public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> assembleTrainingData(JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD, JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD) {
        JavaPairRDD<Long, MatrixBlock> fRDD = ParamservUtils.groupMatrix(featuresRDD);
        JavaPairRDD<Long, MatrixBlock> lRDD = ParamservUtils.groupMatrix(labelsRDD);
        return fRDD.join(lRDD);
    }

    private static JavaPairRDD<Long, MatrixBlock> groupMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) {
        return rdd.mapToPair((PairFunction & Serializable)input -> new Tuple2((Object)((MatrixIndexes)input._1).getRowIndex(), (Object)new Tuple2((Object)((MatrixIndexes)input._1).getColumnIndex(), input._2))).aggregateByKey(new LinkedList(), (Function2 & Serializable)(list, input) -> {
            list.add(input);
            return list;
        }, (Function2 & Serializable)(l1, l2) -> {
            l1.addAll(l2);
            l1.sort((o1, o2) -> ((Long)o1._1).compareTo((Long)o2._1));
            return l1;
        }).mapToPair((PairFunction & Serializable)input -> {
            LinkedList list = (LinkedList)input._2;
            MatrixBlock result = (MatrixBlock)((Tuple2)list.get((int)0))._2;
            for (int i = 1; i < list.size(); ++i) {
                result = ParamservUtils.cbindMatrix(result, (MatrixBlock)((Tuple2)list.get((int)i))._2);
            }
            return new Tuple2(input._1, (Object)result);
        });
    }

    public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, MatrixObject labels, Statement.PSScheme scheme, final int workerNum) {
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        JavaPairRDD<?, ?> featuresRDD = sec.getRDDHandleForMatrixObject(features, Types.FileFormat.BINARY);
        JavaPairRDD<?, ?> labelsRDD = sec.getRDDHandleForMatrixObject(labels, Types.FileFormat.BINARY);
        DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int)features.getNumRows());
        JavaPairRDD result = ParamservUtils.assembleTrainingData(featuresRDD, labelsRDD).flatMapToPair((PairFlatMapFunction)mapper).aggregateByKey(new LinkedList(), new Partitioner(){
            private static final long serialVersionUID = -7937781374718031224L;

            public int getPartition(Object workerID) {
                return (Integer)workerID;
            }

            public int numPartitions() {
                return workerNum;
            }
        }, (Function2 & Serializable)(list, input) -> {
            list.add(input);
            return list;
        }, (Function2 & Serializable)(l1, l2) -> {
            l1.addAll(l2);
            l1.sort((o1, o2) -> ((Long)o1._1).compareTo((Long)o2._1));
            return l1;
        }).mapToPair((PairFunction)new DataPartitionerSparkAggregator(features.getNumColumns(), labels.getNumColumns()));
        if (DMLScript.STATISTICS) {
            Statistics.accPSSetupTime((long)tSetup.stop());
        }
        return result;
    }

    public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean cleanup) {
        return ParamservUtils.accrueGradients(accGradients, gradients, false, cleanup);
    }

    public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par, boolean cleanup) {
        if (accGradients == null) {
            return ParamservUtils.copyList(gradients, cleanup);
        }
        IntStream range = IntStream.range(0, accGradients.getLength());
        (par ? range.parallel() : range).forEach(i -> {
            MatrixBlock mb1 = (MatrixBlock)((MatrixObject)accGradients.getData().get(i)).acquireReadAndRelease();
            MatrixBlock mb2 = (MatrixBlock)((MatrixObject)gradients.getData().get(i)).acquireReadAndRelease();
            mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
        });
        if (cleanup) {
            ParamservUtils.cleanupListObject(gradients);
        }
        return accGradients;
    }

    public static ListObject accrueModels(ListObject accModels, ListObject model, boolean cleanup) {
        return ParamservUtils.accrueModels(accModels, model, false, cleanup);
    }

    public static ListObject accrueModels(ListObject accModels, ListObject model, boolean par, boolean cleanup) {
        if (accModels == null) {
            return ParamservUtils.copyList(model, cleanup);
        }
        IntStream range = IntStream.range(0, accModels.getLength());
        (par ? range.parallel() : range).forEach(i -> {
            MatrixBlock mb1 = (MatrixBlock)((MatrixObject)accModels.getData().get(i)).acquireReadAndRelease();
            MatrixBlock mb2 = (MatrixBlock)((MatrixObject)model.getData().get(i)).acquireReadAndRelease();
            mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
        });
        if (cleanup) {
            ParamservUtils.cleanupListObject(model);
        }
        return accModels;
    }
}

