/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.yarn.ropt;

import java.io.IOException;
import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.cost.CostEstimationWrapper;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTreeConverter;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MatrixDimensionsMetaData;
import org.apache.sysml.yarn.DMLYarnClient;
import org.apache.sysml.yarn.ropt.GridEnumeration;
import org.apache.sysml.yarn.ropt.GridEnumerationEqui;
import org.apache.sysml.yarn.ropt.GridEnumerationExp;
import org.apache.sysml.yarn.ropt.GridEnumerationHybrid;
import org.apache.sysml.yarn.ropt.GridEnumerationMemory;
import org.apache.sysml.yarn.ropt.MRJobResourceInstruction;
import org.apache.sysml.yarn.ropt.ResourceConfig;
import org.apache.sysml.yarn.ropt.YarnClusterAnalyzer;
import org.apache.sysml.yarn.ropt.YarnClusterConfig;
import org.apache.sysml.yarn.ropt.YarnOptimizerUtils;

/**
 * Cost-based optimizer that chooses YARN memory budgets for a compiled SystemML program.
 *
 * <p>For every enumerated control-program (CP) budget, the program is recompiled. Blocks whose
 * MR budget cannot matter are pruned. The best MR budget is then searched per remaining block,
 * and the whole program is costed with those choices. The cheapest combination of CP budget and
 * per-block MR budgets is returned as a {@link ResourceConfig}.
 *
 * <p>Compilation mutates static state: the {@link InfrastructureAnalyzer} memory settings and
 * this class's statistics counters. {@link #optimizeResourceConfig} is {@code synchronized}
 * (presumably to guard that state); the other public helpers are not.
 */
public class ResourceOptimizer {
    private static final Log LOG = LogFactory.getLog(ResourceOptimizer.class);

    /** Lower bound for the CP memory budget in bytes (0x20000000 = 512 MB). */
    public static final long MIN_CP_BUDGET = 0x20000000L;

    // Configuration flags. They are compile-time constants, so uses in other classes are
    // constant-folded; none of them is referenced in this class body.
    public static final boolean INCLUDE_PREDICATES = true;
    public static final boolean PRUNING_SMALL = true;
    public static final boolean PRUNING_UNKNOWN = true;
    public static final boolean COSTS_MAX_PARALLELISM = true;
    public static final boolean COST_INDIVIDUAL_BLOCKS = true;

    /** Ratio between physical container memory and JVM memory used throughout this class. */
    private static final double JVM_OVERHEAD_FACTOR = 1.5;

    /** Number of block (re)compilations since the last {@link #initStatistics()}. */
    private static long _cntCompilePB = 0L;
    /** Number of cost estimations since the last {@link #initStatistics()}. */
    private static long _cntCostPB = 0L;

    /**
     * Searches for the cheapest resource configuration of the given program.
     *
     * @param prog   top-level program blocks
     * @param cc     cluster configuration providing min/max allocation and average core count
     * @param cptype grid enumeration strategy for CP budgets
     * @param mrtype grid enumeration strategy for MR budgets
     * @return the best configuration found; its CP resource is set only if some candidate had
     *         a cost below {@code Double.MAX_VALUE}
     * @throws DMLRuntimeException wrapping any failure during enumeration, compilation or costing
     */
    public static synchronized ResourceConfig optimizeResourceConfig(ArrayList<ProgramBlock> prog, YarnClusterConfig cc, YarnOptimizerUtils.GridEnumType cptype, YarnOptimizerUtils.GridEnumType mrtype) throws DMLRuntimeException {
        ResourceConfig ROpt = null;
        try {
            Timing time = new Timing(true);
            initStatistics();

            // Translate physical container limits into JVM budgets.
            long max = (long) ((double) YarnOptimizerUtils.toB(cc.getMaxAllocationMB()) / JVM_OVERHEAD_FACTOR);
            long minCP = (long) Math.max((double) YarnOptimizerUtils.toB(cc.getMinAllocationMB()) / JVM_OVERHEAD_FACTOR, (double) MIN_CP_BUDGET);
            long minMR = YarnOptimizerUtils.computeMinContraint(minCP, max, cc.getAvgNumCores());

            ArrayList<Long> SRc = enumerateGridPoints(prog, minCP, max, cptype);
            ArrayList<Long> SRm = enumerateGridPoints(prog, minMR, max, mrtype);
            ROpt = new ResourceConfig(prog, minMR);
            double costOpt = Double.MAX_VALUE;

            for (Long rc : SRc) {
                // Compile all blocks for this CP budget; keep only blocks where the MR budget matters.
                ArrayList<ProgramBlock> B = compileProgram(prog, null, (double) rc.longValue(), (double) minMR);
                ArrayList<ProgramBlock> Bp = pruneProgramBlocks(B);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Enum (rc=" + rc + "): |B|=" + B.size() + ", |Bp|=" + Bp.size());
                }

                // Per-block search: memo[i] = {best MR budget, its cost} for Bp[i].
                double[][] memo = initLocalMemoTable(Bp, minMR);
                for (int i = 0; i < Bp.size(); ++i) {
                    ProgramBlock pb = Bp.get(i);
                    for (Long rm : SRm) {
                        recompileProgramBlock(pb, rc, rm);
                        double lcost = getProgramCosts(pb);
                        if (lcost < memo[i][1]) {
                            memo[i][0] = rm.longValue();
                            memo[i][1] = lcost;
                        }
                    }
                }

                // Apply the per-block choices to all blocks and cost the whole program.
                double[][] gmemo = initGlobalMemoTable(B, Bp, memo, minMR);
                recompileProgramBlocks(B, rc, gmemo);
                double gcost = getProgramCosts(B.get(0).getProgram());
                if (gcost < costOpt) {
                    ROpt.setCPResource(rc);
                    ROpt.setMRResources(B, gmemo);
                    costOpt = gcost;
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("Enum (rc=" + rc + "): found new opt w/ cost=" + gcost);
                    }
                }
            }

            LOG.info("Optimization summary:");
            LOG.info("-- optimal plan (rc, rm): " + YarnOptimizerUtils.toMB(ROpt.getCPResource()) + "MB, " + YarnOptimizerUtils.toMB(ROpt.getMaxMRResource()) + "MB");
            LOG.info("-- costs of optimal plan: " + costOpt);
            LOG.info("-- # of block compiles:   " + _cntCompilePB);
            LOG.info("-- # of block costings:   " + _cntCostPB);
            LOG.info("-- optimization time:     " + String.format("%.3f", time.stop() / 1000.0) + " sec.");
            LOG.info("-- optimal plan details:  " + ROpt.serialize());
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        return ROpt;
    }

    /**
     * Compiles a program for a previously optimized resource configuration.
     *
     * <p>First compiles all blocks with the configuration's CP budget and maximum MR budget.
     * Then each block is recompiled with its memoized MR budget.
     *
     * @param prog top-level program blocks
     * @param rc   resource configuration to apply
     * @return the flattened list of compiled blocks
     */
    public static ArrayList<ProgramBlock> compileProgram(ArrayList<ProgramBlock> prog, ResourceConfig rc) throws DMLRuntimeException, HopsException, LopsException, IOException {
        ArrayList<ProgramBlock> B = compileProgram(prog, null, (double) rc.getCPResource(), (double) rc.getMaxMRResource());
        recompileProgramBlocks(B, rc.getCPResource(), rc.getMRResourcesMemo());
        return B;
    }

    /**
     * Recursively compiles a list of program blocks, appending every compiled block to {@code B}.
     *
     * @param B output list; when {@code null}, a new list is allocated and the
     *          {@link InfrastructureAnalyzer} memory limits are set to {@code cp}/{@code mr}
     *          before compiling
     */
    private static ArrayList<ProgramBlock> compileProgram(ArrayList<ProgramBlock> prog, ArrayList<ProgramBlock> B, double cp, double mr) throws DMLRuntimeException, HopsException, LopsException, IOException {
        if (B == null) {
            B = new ArrayList<ProgramBlock>();
            InfrastructureAnalyzer.setLocalMaxMemory((long) cp);
            InfrastructureAnalyzer.setRemoteMaxMemoryMap((long) mr);
            InfrastructureAnalyzer.setRemoteMaxMemoryReduce((long) mr);
            OptimizerUtils.resetDefaultSize();
        }
        for (ProgramBlock pb : prog) {
            compileProgram(pb, B, cp, mr);
        }
        return B;
    }

    /**
     * Compiles one program block and its children, appending compiled blocks to {@code B}.
     *
     * <p>Functions contribute only their children. While/if blocks are added only when they have
     * a predicate. For blocks are added whenever they have a statement block. Any other block is
     * always compiled and added.
     */
    private static ArrayList<ProgramBlock> compileProgram(ProgramBlock pb, ArrayList<ProgramBlock> B, double cp, double mr) throws DMLRuntimeException, HopsException, LopsException, IOException {
        if (pb instanceof FunctionProgramBlock) {
            FunctionProgramBlock fpb = (FunctionProgramBlock) pb;
            compileProgram(fpb.getChildBlocks(), B, cp, mr);
        } else if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock) pb;
            WhileStatementBlock sb = (WhileStatementBlock) pb.getStatementBlock();
            if (sb != null && sb.getPredicateHops() != null) {
                ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getPredicateHops(), new LocalVariableMap(), null, false, false, 0L);
                wpb.setPredicate(inst);
                B.add(wpb);
                ++_cntCompilePB;
            }
            compileProgram(wpb.getChildBlocks(), B, cp, mr);
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            IfStatementBlock sb = (IfStatementBlock) ipb.getStatementBlock();
            if (sb != null && sb.getPredicateHops() != null) {
                ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getPredicateHops(), new LocalVariableMap(), null, false, false, 0L);
                ipb.setPredicate(inst);
                B.add(ipb);
                ++_cntCompilePB;
            }
            compileProgram(ipb.getChildBlocksIfBody(), B, cp, mr);
            compileProgram(ipb.getChildBlocksElseBody(), B, cp, mr);
        } else if (pb instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock) pb;
            ForStatementBlock sb = (ForStatementBlock) fpb.getStatementBlock();
            if (sb != null) {
                // Each of from/to/increment may be absent.
                if (sb.getFromHops() != null) {
                    fpb.setFromInstructions(Recompiler.recompileHopsDag(sb.getFromHops(), new LocalVariableMap(), null, false, false, 0L));
                }
                if (sb.getToHops() != null) {
                    fpb.setToInstructions(Recompiler.recompileHopsDag(sb.getToHops(), new LocalVariableMap(), null, false, false, 0L));
                }
                if (sb.getIncrementHops() != null) {
                    fpb.setIncrementInstructions(Recompiler.recompileHopsDag(sb.getIncrementHops(), new LocalVariableMap(), null, false, false, 0L));
                }
                B.add(fpb);
                ++_cntCompilePB;
            }
            compileProgram(fpb.getChildBlocks(), B, cp, mr);
        } else {
            StatementBlock sb = pb.getStatementBlock();
            ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb, sb.get_hops(), new LocalVariableMap(), null, false, false, 0L);
            pb.setInstructions(inst);
            B.add(pb);
            ++_cntCompilePB;
        }
        return B;
    }

    /** Recompiles {@code pbs[i]} with CP budget {@code cp} and MR budget {@code memo[i][0]}. */
    private static void recompileProgramBlocks(ArrayList<ProgramBlock> pbs, long cp, double[][] memo) throws DMLRuntimeException, HopsException, LopsException, IOException {
        for (int i = 0; i < pbs.size(); ++i) {
            recompileProgramBlock(pbs.get(i), cp, (long) memo[i][0]);
        }
    }

    /**
     * Recompiles a single block (predicate/loop-bound instructions or body instructions,
     * without descending into children) under the given budgets.
     *
     * <p>The resulting MR jobs are annotated with the task limit implied by those budgets.
     */
    private static void recompileProgramBlock(ProgramBlock pb, long cp, long mr) throws DMLRuntimeException, HopsException, LopsException, IOException {
        InfrastructureAnalyzer.setLocalMaxMemory(cp);
        InfrastructureAnalyzer.setRemoteMaxMemoryMap(mr);
        InfrastructureAnalyzer.setRemoteMaxMemoryReduce(mr);
        OptimizerUtils.resetDefaultSize();
        if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock) pb;
            WhileStatementBlock sb = (WhileStatementBlock) pb.getStatementBlock();
            if (sb != null && sb.getPredicateHops() != null) {
                ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getPredicateHops(), new LocalVariableMap(), null, false, false, 0L);
                wpb.setPredicate(annotateMRJobInstructions(inst, cp, mr));
            }
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            IfStatementBlock sb = (IfStatementBlock) ipb.getStatementBlock();
            if (sb != null && sb.getPredicateHops() != null) {
                ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getPredicateHops(), new LocalVariableMap(), null, false, false, 0L);
                ipb.setPredicate(annotateMRJobInstructions(inst, cp, mr));
            }
        } else if (pb instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock) pb;
            ForStatementBlock sb = (ForStatementBlock) fpb.getStatementBlock();
            if (sb != null) {
                ArrayList<Instruction> inst;
                if (sb.getFromHops() != null) {
                    inst = Recompiler.recompileHopsDag(sb.getFromHops(), new LocalVariableMap(), null, false, false, 0L);
                    fpb.setFromInstructions(annotateMRJobInstructions(inst, cp, mr));
                }
                if (sb.getToHops() != null) {
                    inst = Recompiler.recompileHopsDag(sb.getToHops(), new LocalVariableMap(), null, false, false, 0L);
                    fpb.setToInstructions(annotateMRJobInstructions(inst, cp, mr));
                }
                if (sb.getIncrementHops() != null) {
                    inst = Recompiler.recompileHopsDag(sb.getIncrementHops(), new LocalVariableMap(), null, false, false, 0L);
                    fpb.setIncrementInstructions(annotateMRJobInstructions(inst, cp, mr));
                }
            }
        } else {
            StatementBlock sb = pb.getStatementBlock();
            ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb, sb.get_hops(), new LocalVariableMap(), null, false, false, 0L);
            pb.setInstructions(annotateMRJobInstructions(inst, cp, mr));
        }
        ++_cntCompilePB;
    }

    /**
     * Replaces each {@link MRJobInstruction} in {@code inst} (in place) with an
     * {@link MRJobResourceInstruction} carrying a maximum task count.
     *
     * <p>The count is (total cluster memory minus the CP allocation) divided by the allocation
     * per MR task.
     *
     * @return {@code inst} itself, or {@code null} if {@code inst} was {@code null}
     * @throws DMLRuntimeException wrapping any failure
     */
    private static ArrayList<Instruction> annotateMRJobInstructions(ArrayList<Instruction> inst, long cp, long mr) throws DMLRuntimeException {
        if (inst == null) {
            return inst;
        }
        try {
            // The task limit depends only on cp, mr and the cluster, not on the instruction.
            // Compute it once, lazily, so instruction lists without MR jobs make no extra calls.
            long maxMRTasks = 0L;
            boolean limitComputed = false;
            for (int i = 0; i < inst.size(); ++i) {
                Instruction linst = inst.get(i);
                if (!(linst instanceof MRJobInstruction)) {
                    continue;
                }
                MRJobResourceInstruction newlinst = new MRJobResourceInstruction((MRJobInstruction) linst);
                if (!limitComputed) {
                    long maxMemPerNode = YarnClusterAnalyzer.getMaxAllocationBytes();
                    long nNodes = YarnClusterAnalyzer.getNumNodes();
                    long totalMem = nNodes * maxMemPerNode;
                    maxMRTasks = (totalMem - DMLYarnClient.computeMemoryAllocation(cp)) / DMLYarnClient.computeMemoryAllocation(mr);
                    limitComputed = true;
                }
                newlinst.setMaxMRTasks(maxMRTasks);
                inst.set(i, newlinst);
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        return inst;
    }

    /**
     * Estimates the execution time of a single block.
     *
     * <p>The estimate uses a fresh execution context whose variables are placeholder matrix
     * objects for every matrix read in the block's hops.
     */
    private static double getProgramCosts(ProgramBlock pb) throws DMLRuntimeException, HopsException {
        LocalVariableMap vars = new LocalVariableMap();
        collectReadVariables(pb.getStatementBlock().get_hops(), vars);
        ExecutionContext ec = ExecutionContextFactory.createContext(false, null);
        ec.setVariables(vars);
        double val = CostEstimationWrapper.getTimeEstimate(pb, ec, false);
        ++_cntCostPB;
        return val;
    }

    /** Estimates the execution time of a whole program in a fresh execution context. */
    private static double getProgramCosts(Program prog) throws DMLRuntimeException {
        ExecutionContext ec = ExecutionContextFactory.createContext();
        double val = CostEstimationWrapper.getTimeEstimate(prog, ec);
        ++_cntCostPB;
        return val;
    }

    /** Registers a placeholder variable in {@code vars} for every matrix read reachable from {@code hops}. */
    private static void collectReadVariables(ArrayList<Hop> hops, LocalVariableMap vars) {
        if (hops != null) {
            Hop.resetVisitStatus(hops);
            for (Hop hop : hops) {
                collectReadVariables(hop, vars);
            }
        }
    }

    /** Post-order DAG traversal; each hop is processed at most once per reset. */
    private static void collectReadVariables(Hop hop, LocalVariableMap vars) {
        // Without the visited check, shared sub-DAGs were re-traversed once per parent path.
        if (hop == null || hop.isVisited()) {
            return;
        }
        for (Hop hi : hop.getInput()) {
            collectReadVariables(hi, vars);
        }
        if (hop instanceof DataOp && hop.getDataType() == Expression.DataType.MATRIX) {
            Hop.DataOpTypes optype = ((DataOp) hop).getDataOpType();
            if (optype == Hop.DataOpTypes.TRANSIENTREAD || optype == Hop.DataOpTypes.PERSISTENTREAD) {
                String varname = hop.getName();
                MatrixCharacteristics mc = new MatrixCharacteristics(hop.getDim1(), hop.getDim2(), (int) hop.getRowsInBlock(), (int) hop.getColsInBlock(), hop.getNnz());
                MatrixDimensionsMetaData md = new MatrixDimensionsMetaData(mc);
                // "/tmp" is a placeholder file name; presumably only the metadata matters for costing.
                MatrixObject mo = new MatrixObject(Expression.ValueType.DOUBLE, "/tmp", md);
                vars.put(varname, mo);
            }
        }
        hop.setVisited();
    }

    /**
     * Returns the subset of {@code B}, in original order, whose MR budget can influence costs.
     *
     * <p>A block is kept if it compiled to at least one MR job and not all of its MR operators
     * have unknown dimensions.
     */
    private static ArrayList<ProgramBlock> pruneProgramBlocks(ArrayList<ProgramBlock> B) throws HopsException {
        ArrayList<ProgramBlock> Bp = new ArrayList<ProgramBlock>();
        for (ProgramBlock pb : B) {
            if (OptTreeConverter.containsMRJobInstruction(pb.getInstructions(), false, true)
                    && !pruneHasOnlyUnknownMR(pb)) {
                Bp.add(pb);
            }
        }
        return Bp;
    }

    /**
     * Returns true if every MR operator in the block's hops (predicate, loop bounds or body)
     * has an unknown dimension in itself or one of its inputs.
     */
    private static boolean pruneHasOnlyUnknownMR(ProgramBlock pb) throws HopsException {
        if (pb instanceof WhileProgramBlock) {
            WhileStatementBlock sb = (WhileStatementBlock) pb.getStatementBlock();
            resetVisitStatusIfPresent(sb.getPredicateHops());
            return pruneHasOnlyUnknownMR(sb.getPredicateHops());
        }
        if (pb instanceof IfProgramBlock) {
            IfStatementBlock sb = (IfStatementBlock) pb.getStatementBlock();
            resetVisitStatusIfPresent(sb.getPredicateHops());
            return pruneHasOnlyUnknownMR(sb.getPredicateHops());
        }
        if (pb instanceof ForProgramBlock) {
            ForStatementBlock sb = (ForStatementBlock) pb.getStatementBlock();
            // from/to/increment hops may be null (compileProgram checks each one), so reset
            // defensively; a null hop counts as "only unknown" in the Hop overload.
            resetVisitStatusIfPresent(sb.getFromHops());
            resetVisitStatusIfPresent(sb.getToHops());
            resetVisitStatusIfPresent(sb.getIncrementHops());
            return pruneHasOnlyUnknownMR(sb.getFromHops())
                && pruneHasOnlyUnknownMR(sb.getToHops())
                && pruneHasOnlyUnknownMR(sb.getIncrementHops());
        }
        StatementBlock sb = pb.getStatementBlock();
        return pruneHasOnlyUnknownMR(sb.get_hops());
    }

    /** Resets the visit status of {@code hop} unless it is {@code null}. */
    private static void resetVisitStatusIfPresent(Hop hop) throws HopsException {
        if (hop != null) {
            hop.resetVisitStatus();
        }
    }

    /** List variant: false for a {@code null} list, otherwise the conjunction over all roots. */
    private static boolean pruneHasOnlyUnknownMR(ArrayList<Hop> hops) throws HopsException {
        if (hops == null) {
            return false;
        }
        boolean ret = true;
        Hop.resetVisitStatus(hops);
        for (Hop hop : hops) {
            ret &= pruneHasOnlyUnknownMR(hop);
        }
        return ret;
    }

    /**
     * DAG variant: true if every MR-exec hop reachable from {@code hop} (and not yet visited)
     * has unknown dimensions in itself or one of its inputs.
     */
    private static boolean pruneHasOnlyUnknownMR(Hop hop) {
        if (hop == null || hop.isVisited()) {
            return true;
        }
        boolean ret = true;
        for (Hop hi : hop.getInput()) {
            ret &= pruneHasOnlyUnknownMR(hi);
        }
        if (hop.getExecType() == LopProperties.ExecType.MR) {
            boolean anyUnknown = !hop.dimsKnown();
            for (Hop hi : hop.getInput()) {
                anyUnknown |= !hi.dimsKnown();
            }
            ret &= anyUnknown;
        }
        hop.setVisited();
        return ret;
    }

    /**
     * Enumerates candidate memory budgets in {@code [min, max]} using the requested strategy.
     *
     * @throws DMLRuntimeException for an unsupported enumeration type
     */
    private static ArrayList<Long> enumerateGridPoints(ArrayList<ProgramBlock> prog, long min, long max, YarnOptimizerUtils.GridEnumType type) throws DMLRuntimeException, HopsException {
        GridEnumeration ge;
        switch (type) {
            case EQUI_GRID:
                ge = new GridEnumerationEqui(prog, min, max);
                break;
            case EXP_GRID:
                ge = new GridEnumerationExp(prog, min, max);
                break;
            case MEM_EQUI_GRID:
                ge = new GridEnumerationMemory(prog, min, max);
                break;
            case HYBRID_MEM_EXP_GRID:
                ge = new GridEnumerationHybrid(prog, min, max);
                break;
            default:
                throw new DMLRuntimeException("Unsupported grid enumeration type: " + type);
        }
        ArrayList<Long> ret = ge.enumerateGridPoints();
        if (LOG.isDebugEnabled()) {
            LOG.debug("Gen: min=" + YarnOptimizerUtils.toMB(min) + ", max=" + YarnOptimizerUtils.toMB(max) + ", npoints=" + ret.size());
        }
        return ret;
    }

    /**
     * Creates the per-block memo for {@code Bp}: {@code memo[i] = {min, cost of Bp[i]'s program}}.
     *
     * <p>The initial cost is that of the block's enclosing program.
     */
    private static double[][] initLocalMemoTable(ArrayList<ProgramBlock> Bp, double min) throws DMLRuntimeException {
        int len = Bp.size();
        double[][] memo = new double[len][2];
        for (int i = 0; i < len; ++i) {
            ProgramBlock pb = Bp.get(i);
            ExecutionContext ec = ExecutionContextFactory.createContext();
            memo[i][0] = min;
            memo[i][1] = CostEstimationWrapper.getTimeEstimate(pb.getProgram(), ec);
        }
        return memo;
    }

    /**
     * Expands the local memo over {@code Bp} to all blocks {@code B}.
     *
     * <p>Blocks of {@code B} that are also in {@code Bp} (matched by identity, in order) take the
     * chosen MR budget from {@code lmemo}; all others get {@code min}. Costs are set to -1.
     */
    private static double[][] initGlobalMemoTable(ArrayList<ProgramBlock> B, ArrayList<ProgramBlock> Bp, double[][] lmemo, double min) {
        int len = B.size();
        int lenp = Bp.size();
        double[][] memo = new double[len][2];
        for (int i = 0; i < len; ++i) {
            memo[i][0] = min;
            memo[i][1] = -1.0;
        }
        // Bp is an order-preserving subsequence of B, so a single merge-style pass suffices.
        int j = 0;
        for (int i = 0; i < len && j < lenp; ++i) {
            if (B.get(i) == Bp.get(j)) {
                memo[i][0] = lmemo[j][0];
                ++j;
            }
        }
        return memo;
    }

    /** Resets the compile and costing counters reported in the optimization summary. */
    public static void initStatistics() {
        _cntCompilePB = 0L;
        _cntCostPB = 0L;
    }

    /**
     * Converts a JVM memory size to the physical container size (bytes) using a 1.5x overhead.
     *
     * @param mrRealRun if true, the result is raised to at least the cluster's minimum MR
     *                  container size
     */
    public static long jvmToPhy(long jvm, boolean mrRealRun) {
        long ret = (long) Math.ceil((double) jvm * JVM_OVERHEAD_FACTOR);
        if (mrRealRun) {
            long lowerBound = (long) YarnClusterAnalyzer.getMinMRContarinerPhyMB() * 1024L * 1024L;
            if (ret < lowerBound) {
                return lowerBound;
            }
        }
        return ret;
    }

    /** Converts an optimizer memory budget to a JVM size by dividing by the memory utilization factor. */
    public static long budgetToJvm(double budget) {
        return (long) Math.ceil(budget / OptimizerUtils.MEM_UTIL_FACTOR);
    }

    /** Converts a physical container size to an optimizer memory budget (inverse of the above two steps). */
    public static double phyToBudget(long physical) throws IOException {
        return (double) physical / JVM_OVERHEAD_FACTOR * OptimizerUtils.MEM_UTIL_FACTOR;
    }
}

