/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.cost;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.StringTokenizer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.cost.VarStats;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ExternalFunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.MRInstructionParser;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysml.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction;
import org.apache.sysml.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Base class for estimating the execution time of a runtime program.
 *
 * <p>The estimator walks the program block hierarchy, keeps a map from variable
 * name to {@link VarStats} up to date while stepping through the instructions,
 * and delegates the cost of each individual instruction to the subclass via
 * {@link #getCPInstTimeEstimate} and {@link #getMRJobInstTimeEstimate}.
 */
public abstract class CostEstimator {
    /** Logger, also visible to subclasses. */
    protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());
    // NOTE(review): presumably the iteration count assumed for loops whose trip
    // count is not known; the value matches the factor 15 applied to while
    // loops and the fallback handed to OptimizerUtils.getNumIterations - confirm.
    private static final int DEFAULT_NUMITER = 15;
    /**
     * Placeholder statistics: 1 x 1 dimensions, block sizes and non-zero count
     * of -1, not in memory. This single instance is stored directly in the
     * statistics map and in extracted arrays, so a write to one of its fields
     * is visible everywhere the placeholder is used.
     */
    protected static final VarStats _unknownStats = new VarStats(1L, 1L, -1L, -1L, -1L, false);
    /**
     * Statistics used for everything that is not a matrix object and for
     * operands without an entry in the statistics map: 1 x 1, one non-zero,
     * in memory. Shared instance, same caveat as {@link #_unknownStats}.
     */
    protected static final VarStats _scalarStats = new VarStats(1L, 1L, 1L, 1L, 1L, true);

    /**
     * Estimates the execution time of a complete runtime program.
     *
     * <p>The statistics map is first seeded from the given variable map; the
     * estimates of all top-level program blocks are then added up. Every
     * top-level block is processed recursively and starts with its own, empty
     * set of visited functions.
     *
     * @param rtprog runtime program whose top-level blocks are costed
     * @param vars   variables used to seed {@code stats}
     * @param stats  variable statistics, updated in place
     * @return sum of the estimates of all top-level program blocks
     * @throws DMLRuntimeException if estimation of a block fails
     */
    public double getTimeEstimate(Program rtprog, LocalVariableMap vars, HashMap<String, VarStats> stats) throws DMLRuntimeException {
        maintainVariableStatistics(vars, stats);

        double total = 0.0;
        for (ProgramBlock block : rtprog.getProgramBlocks()) {
            HashSet<String> visitedFunctions = new HashSet<String>();
            total += rGetTimeEstimate(block, stats, visitedFunctions, true);
        }
        return total;
    }

    /**
     * Estimates the execution time of a single program block.
     *
     * @param pb        program block to cost
     * @param vars      variables used to seed {@code stats}
     * @param stats     variable statistics, updated in place
     * @param recursive whether child blocks of control-flow and function blocks
     *                  are included in the estimate
     * @return estimate for {@code pb}
     * @throws DMLRuntimeException if estimation fails
     */
    public double getTimeEstimate(ProgramBlock pb, LocalVariableMap vars, HashMap<String, VarStats> stats, boolean recursive) throws DMLRuntimeException {
        maintainVariableStatistics(vars, stats);
        HashSet<String> visitedFunctions = new HashSet<String>();
        double estimate = rGetTimeEstimate(pb, stats, visitedFunctions, recursive);
        return estimate;
    }

    /**
     * Recursively estimates the execution time of one program block.
     *
     * <ul>
     *   <li>while block: sum of the children times {@code DEFAULT_NUMITER}</li>
     *   <li>if block: if-body cost when there is no (or an empty) else body,
     *       otherwise the average of if-body and else-body cost</li>
     *   <li>for block: sum of the children times the iteration count returned by
     *       {@code getNumIterations}</li>
     *   <li>function block that is not an external function: sum of the children</li>
     *   <li>any other block: sum over its instructions</li>
     * </ul>
     *
     * <p>For the first four kinds the children only contribute when
     * {@code recursive} is set; otherwise the result is 0.
     *
     * @param pb        program block to cost
     * @param stats     variable statistics, updated in place while instructions are visited
     * @param memoFunc  keys of the functions currently being expanded; a call to a
     *                  function already in this set is not expanded again
     * @param recursive whether child blocks are included
     * @return estimate for {@code pb}
     * @throws DMLRuntimeException if estimation of an instruction or child fails
     */
    private double rGetTimeEstimate(ProgramBlock pb, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) throws DMLRuntimeException {
        if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock whileBlock = (WhileProgramBlock)pb;
            double body = recursive ? this.sumChildEstimates(whileBlock.getChildBlocks(), stats, memoFunc, recursive) : 0.0;
            return body * DEFAULT_NUMITER;
        }
        if (pb instanceof IfProgramBlock) {
            if (!recursive) {
                return 0.0;
            }
            IfProgramBlock ifBlock = (IfProgramBlock)pb;
            double ifCost = this.sumChildEstimates(ifBlock.getChildBlocksIfBody(), stats, memoFunc, recursive);
            Iterable<? extends ProgramBlock> elseBody = ifBlock.getChildBlocksElseBody();
            if (elseBody == null) {
                return ifCost;
            }
            double elseCost = 0.0;
            boolean hasElse = false;
            for (ProgramBlock child : elseBody) {
                elseCost += this.rGetTimeEstimate(child, stats, memoFunc, recursive);
                hasElse = true;
            }
            // Both branches are weighted equally. The division is applied once to
            // the two branch totals; halving the running sum per else child would
            // discount the if branch by 2^n for an else body with n blocks.
            return hasElse ? (ifCost + elseCost) / 2.0 : ifCost;
        }
        if (pb instanceof ForProgramBlock) {
            ForProgramBlock forBlock = (ForProgramBlock)pb;
            double body = recursive ? this.sumChildEstimates(forBlock.getChildBlocks(), stats, memoFunc, recursive) : 0.0;
            return body * (double)CostEstimator.getNumIterations(stats, forBlock);
        }
        if (pb instanceof FunctionProgramBlock && !(pb instanceof ExternalFunctionProgramBlock)) {
            FunctionProgramBlock funcBlock = (FunctionProgramBlock)pb;
            return recursive ? this.sumChildEstimates(funcBlock.getChildBlocks(), stats, memoFunc, recursive) : 0.0;
        }
        return this.estimateInstructionCosts(pb, stats, memoFunc, recursive);
    }

    /** Adds up the recursive estimates of the given child blocks, in iteration order. */
    private double sumChildEstimates(Iterable<? extends ProgramBlock> children, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) throws DMLRuntimeException {
        double sum = 0.0;
        for (ProgramBlock child : children) {
            sum += this.rGetTimeEstimate(child, stats, memoFunc, recursive);
        }
        return sum;
    }

    /** Adds up the estimates of the CP and MR-job instructions of a block, expanding function calls once per call chain. */
    private double estimateInstructionCosts(ProgramBlock pb, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) throws DMLRuntimeException {
        double ret = 0.0;
        for (Instruction inst : pb.getInstructions()) {
            if (inst instanceof CPInstruction) {
                // update the statistics first so the output of this instruction is known
                CostEstimator.maintainCPInstVariableStatistics((CPInstruction)inst, stats);
                Object[] o = CostEstimator.extractCPInstStatistics(inst, stats);
                VarStats[] vs = (VarStats[])o[0];
                String[] attr = (String[])o[1];
                ret += this.getCPInstTimeEstimate(inst, vs, attr);

                if (!(inst instanceof FunctionCallCPInstruction)) {
                    continue;
                }
                FunctionCallCPInstruction finst = (FunctionCallCPInstruction)inst;
                String fkey = DMLProgram.constructFunctionKey(finst.getNamespace(), finst.getFunctionName());
                // skip functions that are already being expanded (recursion) and
                // blocks that are not attached to a program
                if (memoFunc.contains(fkey) || pb.getProgram() == null) {
                    continue;
                }
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Begin Function " + fkey);
                }
                memoFunc.add(fkey);
                Program prog = pb.getProgram();
                FunctionProgramBlock fpb = prog.getFunctionProgramBlock(finst.getNamespace(), finst.getFunctionName());
                ret += this.rGetTimeEstimate(fpb, stats, memoFunc, recursive);
                memoFunc.remove(fkey);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("End Function " + fkey);
                }
            } else if (inst instanceof MRJobInstruction) {
                this.maintainMRJobInstVariableStatistics(inst, stats);
                Object[] o = CostEstimator.extractMRJobInstStatistics(inst, stats);
                VarStats[] vs = (VarStats[])o[0];
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Begin MRJob type=" + ((MRJobInstruction)inst).getJobType());
                }
                ret += this.getMRJobInstTimeEstimate(inst, vs, null);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("End MRJob");
                }
                // drop the positional entries that were added for this job
                CostEstimator.cleanupMRJobVariableStatistics(inst, stats);
            }
        }
        return ret;
    }

    /**
     * Seeds the statistics map from a variable map.
     *
     * <p>Matrix objects get statistics built from their matrix characteristics
     * and are marked in-memory when their status string equals {@code "CACHED"};
     * every other variable is mapped to the shared scalar statistics.
     *
     * @param vars  variables to take the statistics from
     * @param stats target map; existing entries with the same name are replaced
     * @throws DMLRuntimeException declared for callers; nothing in this method
     *                             visibly throws it
     */
    private static void maintainVariableStatistics(LocalVariableMap vars, HashMap<String, VarStats> stats) throws DMLRuntimeException {
        for (String name : vars.keySet()) {
            Data value = vars.get(name);
            if (!(value instanceof MatrixObject)) {
                stats.put(name, _scalarStats);
                continue;
            }
            MatrixObject matrix = (MatrixObject)value;
            MatrixCharacteristics dims = matrix.getMatrixCharacteristics();
            long rows = dims.getRows();
            long cols = dims.getCols();
            long rowsPerBlock = dims.getRowsPerBlock();
            long colsPerBlock = dims.getColsPerBlock();
            long nonZeros = dims.getNonZeros();
            boolean cached = matrix.getStatusAsString().equals("CACHED");
            stats.put(name, new VarStats(rows, cols, rowsPerBlock, colsPerBlock, nonZeros, cached));
        }
    }

    /**
     * Updates the statistics map for the effect of one CP instruction.
     *
     * <ul>
     *   <li>{@code createvar}: registers statistics parsed from the instruction
     *       parts (name in part 1; rows, columns, block sizes and non-zeros in
     *       parts 6 to 10); instructions with fewer parts are ignored</li>
     *   <li>{@code cpvar}: the target name shares the statistics of the source</li>
     *   <li>{@code mvvar}: the statistics move from the source to the target name</li>
     *   <li>{@code rmvar}: the entry is removed</li>
     *   <li>data generation and string initialization: registers the output with
     *       the dimensions given by the instruction, marked in-memory</li>
     *   <li>function call: every bound output is registered as unknown</li>
     * </ul>
     *
     * Other instructions leave the map untouched.
     *
     * @param inst  instruction to account for
     * @param stats variable statistics, updated in place
     */
    private static void maintainCPInstVariableStatistics(CPInstruction inst, HashMap<String, VarStats> stats) {
        if (inst instanceof VariableCPInstruction) {
            String optype = inst.getOpcode();
            String[] parts = InstructionUtils.getInstructionParts(inst.toString());
            if (optype.equals("createvar")) {
                // parts[10] is read below, so at least 11 parts are required
                if (parts.length < 11) {
                    return;
                }
                String varname = parts[1];
                long rlen = Long.parseLong(parts[6]);
                long clen = Long.parseLong(parts[7]);
                long brlen = Long.parseLong(parts[8]);
                long bclen = Long.parseLong(parts[9]);
                long nnz = Long.parseLong(parts[10]);
                stats.put(varname, new VarStats(rlen, clen, brlen, bclen, nnz, false));
            } else if (optype.equals("cpvar")) {
                String source = parts[1];
                String target = parts[2];
                stats.put(target, stats.get(source));
            } else if (optype.equals("mvvar")) {
                String source = parts[1];
                String target = parts[2];
                stats.put(target, stats.remove(source));
            } else if (optype.equals("rmvar")) {
                stats.remove(parts[1]);
            }
        } else if (inst instanceof DataGenCPInstruction) {
            DataGenCPInstruction randInst = (DataGenCPInstruction)inst;
            String varname = randInst.output.getName();
            long rlen = randInst.getRows();
            long clen = randInst.getCols();
            long brlen = randInst.getRowsInBlock();
            long bclen = randInst.getColsInBlock();
            // expected number of non-zeros from the requested sparsity
            long nnz = (long)(randInst.getSparsity() * (double)rlen * (double)clen);
            stats.put(varname, new VarStats(rlen, clen, brlen, bclen, nnz, true));
        } else if (inst instanceof StringInitCPInstruction) {
            StringInitCPInstruction iinst = (StringInitCPInstruction)inst;
            String varname = iinst.output.getName();
            long rlen = iinst.getRows();
            long clen = iinst.getCols();
            // every cell is counted as a non-zero
            VarStats vs = new VarStats(rlen, clen, ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), rlen * clen, true);
            stats.put(varname, vs);
        } else if (inst instanceof FunctionCallCPInstruction) {
            FunctionCallCPInstruction finst = (FunctionCallCPInstruction)inst;
            for (String varname : finst.getBoundOutputParamNames()) {
                stats.put(varname, _unknownStats);
            }
        }
    }

    /*
     * WARNING - void declaration
     */
    private void maintainMRJobInstVariableStatistics(Instruction inst, HashMap<String, VarStats> stats) throws DMLRuntimeException {
        MRJobInstruction jobinst = (MRJobInstruction)inst;
        String[] inVars = jobinst.getInputVars();
        int index = -1;
        for (String varname : inVars) {
            void var10_17;
            VarStats varStats = stats.get(varname);
            if (varStats == null) {
                VarStats varStats2 = _unknownStats;
            }
            stats.put(String.valueOf(++index), (VarStats)var10_17);
        }
        String rdInst = jobinst.getIv_randInstructions();
        if (rdInst != null && rdInst.length() > 0) {
            StringTokenizer st = new StringTokenizer(rdInst, "\u2021");
            while (st.hasMoreTokens()) {
                String[] parts = InstructionUtils.getInstructionParts(st.nextToken());
                byte outIndex = Byte.parseByte(parts[2]);
                long l = parts[3].contains("\u00b6") ? -1L : UtilFunctions.parseToLong(parts[3]);
                long clen = parts[4].contains("\u00b6") ? -1L : UtilFunctions.parseToLong(parts[4]);
                long brlen = Long.parseLong(parts[5]);
                long bclen = Long.parseLong(parts[6]);
                long nnz = (long)(Double.parseDouble(parts[9]) * (double)l * (double)clen);
                VarStats vs = new VarStats(l, clen, brlen, bclen, nnz, false);
                stats.put(String.valueOf(outIndex), vs);
            }
        }
        HashMap<Byte, MatrixCharacteristics> dims = new HashMap<Byte, MatrixCharacteristics>();
        for (Map.Entry<String, VarStats> e : stats.entrySet()) {
            if (!UtilFunctions.isIntegerNumber(e.getKey())) continue;
            byte by = Byte.parseByte(e.getKey());
            VarStats vs = e.getValue();
            if (vs == null) continue;
            MatrixCharacteristics mc = new MatrixCharacteristics(vs._rlen, vs._clen, (int)vs._brlen, (int)vs._bclen, (long)vs._nnz);
            dims.put(by, mc);
        }
        String[] instCat = new String[]{jobinst.getIv_randInstructions(), jobinst.getIv_recordReaderInstructions(), jobinst.getIv_instructionsInMapper(), jobinst.getIv_shuffleInstructions(), jobinst.getIv_aggInstructions(), jobinst.getIv_otherInstructions()};
        for (String linstCat : instCat) {
            String[] linst;
            if (linstCat == null || linstCat.length() <= 0) continue;
            for (String instStr : linst = linstCat.split("\u2021")) {
                String instStr2 = this.replaceInstructionPatch(instStr);
                MRInstruction mrinst = MRInstructionParser.parseSingleInstruction(instStr2);
                MatrixCharacteristics.computeDimension(dims, mrinst);
            }
        }
        for (Map.Entry entry : dims.entrySet()) {
            byte ix = (Byte)entry.getKey();
            if (stats.containsKey(String.valueOf(ix))) continue;
            MatrixCharacteristics mc = (MatrixCharacteristics)entry.getValue();
            VarStats vs = new VarStats(mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros(), false);
            stats.put(String.valueOf(ix), vs);
        }
        String[] outLabels = jobinst.getOutputVars();
        byte[] byArray = jobinst.getIv_resultIndices();
        for (int i = 0; i < byArray.length; ++i) {
            String varname = outLabels[i];
            VarStats varvs = stats.get(String.valueOf(byArray[i]));
            if (varvs == null) {
                varvs = stats.get(outLabels[i]);
            }
            varvs._inmem = false;
            stats.put(varname, varvs);
        }
    }

    /**
     * Replaces every placeholder in an instruction string by the constant
     * {@code "1"}.
     *
     * <p>A placeholder is the text from one delimiter character (the pilcrow
     * sign) up to and including the next one. All occurrences of the same
     * placeholder are replaced together. The placeholder is matched literally,
     * so names containing regular-expression metacharacters are handled, and a
     * trailing delimiter without a closing partner is left untouched.
     *
     * @param inst instruction string, possibly containing placeholders
     * @return the instruction string with all complete placeholders replaced
     */
    protected String replaceInstructionPatch(String inst) {
        final String delim = "\u00b6";
        String ret = inst;
        int open = ret.indexOf(delim);
        while (open >= 0) {
            int close = ret.indexOf(delim, open + 1);
            if (close < 0) {
                // unmatched delimiter: nothing sensible to replace
                break;
            }
            String placeholder = ret.substring(open, close + 1);
            ret = ret.replace(placeholder, "1");
            // everything before 'open' is already free of delimiters
            open = ret.indexOf(delim, open);
        }
        return ret;
    }

    /**
     * Collects the statistics and attributes needed to cost one CP instruction.
     *
     * <p>The result is a two-element array: element 0 is a {@code VarStats[3]}
     * holding the statistics of the first input, the second input and the
     * output (for multi-return builtins: input, first output, second output);
     * element 1 is an optional {@code String[]} of instruction-specific
     * attributes, or {@code null}. Operands without an entry in {@code stats}
     * are treated as scalars, except for multi-return builtins where they are
     * treated as unknown. Instruction kinds that are not handled explicitly get
     * unknown statistics in all three slots.
     *
     * <p>The statistics in slot 2 are marked as in-memory before returning.
     *
     * @param inst  instruction to inspect
     * @param stats current variable statistics
     * @return {@code {VarStats[3], String[] or null}}
     */
    private static Object[] extractCPInstStatistics(Instruction inst, HashMap<String, VarStats> stats) {
        VarStats[] vs = new VarStats[3];
        String[] attr = null;

        if (inst instanceof UnaryCPInstruction) {
            if (inst instanceof DataGenCPInstruction) {
                DataGenCPInstruction rinst = (DataGenCPInstruction)inst;
                vs[0] = _unknownStats;
                vs[1] = _unknownStats;
                vs[2] = stats.get(rinst.output.getName());
                // 0: all zeros, 1: dense constant, 2: anything else
                int type = 2;
                if (rinst.getMinValue() == 0.0 && rinst.getMaxValue() == 0.0) {
                    type = 0;
                } else if (rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue()) {
                    type = 1;
                }
                attr = new String[]{String.valueOf(type)};
            } else if (inst instanceof StringInitCPInstruction) {
                StringInitCPInstruction rinst = (StringInitCPInstruction)inst;
                vs[0] = _unknownStats;
                vs[1] = _unknownStats;
                vs[2] = stats.get(rinst.output.getName());
            } else {
                UnaryCPInstruction uinst = (UnaryCPInstruction)inst;
                vs[0] = CostEstimator.statsOrScalar(stats, uinst.input1.getName());
                vs[1] = _unknownStats;
                vs[2] = CostEstimator.statsOrScalar(stats, uinst.output.getName());
                if (inst instanceof MMTSJCPInstruction) {
                    String type = ((MMTSJCPInstruction)inst).getMMTSJType().toString();
                    attr = new String[]{type};
                } else if (inst instanceof AggregateUnaryCPInstruction) {
                    String[] parts = InstructionUtils.getInstructionParts(inst.toString());
                    if (parts[0].equals("cm")) {
                        // second-to-last part of a "cm" instruction is passed on as attribute
                        attr = new String[]{parts[parts.length - 2]};
                    }
                }
            }
        } else if (inst instanceof BinaryCPInstruction) {
            BinaryCPInstruction binst = (BinaryCPInstruction)inst;
            vs[0] = CostEstimator.statsOrScalar(stats, binst.input1.getName());
            vs[1] = CostEstimator.statsOrScalar(stats, binst.input2.getName());
            vs[2] = CostEstimator.statsOrScalar(stats, binst.output.getName());
        } else if (inst instanceof AggregateTernaryCPInstruction) {
            AggregateTernaryCPInstruction tinst = (AggregateTernaryCPInstruction)inst;
            vs[0] = CostEstimator.statsOrScalar(stats, tinst.input1.getName());
            vs[1] = CostEstimator.statsOrScalar(stats, tinst.input2.getName());
            vs[2] = CostEstimator.statsOrScalar(stats, tinst.output.getName());
        } else if (inst instanceof ParameterizedBuiltinCPInstruction) {
            String[] parts = InstructionUtils.getInstructionParts(inst.toString());
            String opcode = parts[0];
            if (opcode.equals("groupedagg")) {
                HashMap<String, String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
                String fn = paramsMap.get("fn");
                String order = paramsMap.get("order");
                CMOperator.AggregateOperationTypes type = CMOperator.getAggOpType(fn, order);
                attr = new String[]{String.valueOf(type.ordinal())};
            } else if (opcode.equals("rmempty")) {
                HashMap<String, String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
                attr = new String[]{String.valueOf(paramsMap.get("margin").equals("rows") ? 0 : 1)};
            }
            // NOTE(review): skips the first seven characters of part 1, presumably a
            // "target=" prefix - confirm against the instruction format
            vs[0] = CostEstimator.statsOrScalar(stats, parts[1].substring(7).replaceAll("\u00b6", ""));
            vs[1] = _unknownStats;
            vs[2] = CostEstimator.statsOrScalar(stats, parts[parts.length - 1]);
        } else if (inst instanceof MultiReturnBuiltinCPInstruction) {
            MultiReturnBuiltinCPInstruction minst = (MultiReturnBuiltinCPInstruction)inst;
            vs[0] = stats.get(minst.input1.getName());
            vs[1] = stats.get(minst.getOutput(0).getName());
            vs[2] = stats.get(minst.getOutput(1).getName());
            // operands without statistics are treated as unknown so that neither
            // the in-memory update below nor the cost function hits a null entry
            for (int i = 0; i < vs.length; ++i) {
                if (vs[i] == null) {
                    vs[i] = _unknownStats;
                }
            }
        } else if (inst instanceof VariableCPInstruction) {
            CostEstimator.setUnknownStats(vs);
            VariableCPInstruction varinst = (VariableCPInstruction)inst;
            if (varinst.getOpcode().equals("write")) {
                if (stats.containsKey(varinst.getInput1().getName())) {
                    vs[0] = stats.get(varinst.getInput1().getName());
                }
                attr = new String[]{varinst.getInput3().getName()};
            }
        } else {
            CostEstimator.setUnknownStats(vs);
        }

        // The output is marked as in-memory. NOTE(review): when slot 2 holds one
        // of the shared placeholder instances this writes to the shared object -
        // confirm that this is intended.
        if (vs[2] != null) {
            vs[2]._inmem = true;
        }
        return new Object[]{vs, attr};
    }

    /** Returns the statistics registered under {@code name}, or the shared scalar statistics when there are none. */
    private static VarStats statsOrScalar(HashMap<String, VarStats> stats, String name) {
        VarStats found = stats.get(name);
        return (found != null) ? found : _scalarStats;
    }

    /**
     * Fills the first three slots of the given array with the shared
     * unknown-statistics placeholder.
     *
     * @param vs array with at least three slots
     */
    private static void setUnknownStats(VarStats[] vs) {
        for (int slot = 0; slot < 3; slot++) {
            vs[slot] = _unknownStats;
        }
    }

    /**
     * Collects the statistics needed to cost one MR job instruction.
     *
     * <p>Element 0 of the result is a {@code VarStats[]} with one slot per
     * index from 0 up to the largest result index of the job; each slot holds
     * the statistics registered under that index (as a decimal string) or the
     * unknown placeholder when there are none. Element 1 is always {@code null}.
     *
     * @param inst  the MR job instruction (must be an {@code MRJobInstruction})
     * @param stats current variable statistics
     * @return {@code {VarStats[], null}}
     */
    private static Object[] extractMRJobInstStatistics(Instruction inst, HashMap<String, VarStats> stats) {
        byte[] resultIndices = ((MRJobInstruction)inst).getIv_resultIndices();

        // largest result index; stays at -1 when there is none
        int highest = -1;
        for (byte ix : resultIndices) {
            if (ix > highest) {
                highest = ix;
            }
        }

        VarStats[] byIndex = new VarStats[highest + 1];
        for (int pos = 0; pos < byIndex.length; pos++) {
            VarStats known = stats.get(String.valueOf(pos));
            byIndex[pos] = (known != null) ? known : _unknownStats;
        }
        return new Object[]{byIndex, null};
    }

    /**
     * Removes the index-keyed entries of an MR job from the statistics map.
     *
     * <p>Every key from "0" up to the largest result index of the job is
     * removed; statistics objects removed this way are marked as not in memory.
     *
     * @param inst  the MR job instruction (must be an {@code MRJobInstruction})
     * @param stats variable statistics, updated in place
     */
    private static void cleanupMRJobVariableStatistics(Instruction inst, HashMap<String, VarStats> stats) {
        byte[] resultIndices = ((MRJobInstruction)inst).getIv_resultIndices();

        int highest = -1;
        for (byte ix : resultIndices) {
            highest = Math.max(highest, ix);
        }

        for (int pos = 0; pos <= highest; pos++) {
            VarStats removed = stats.remove(String.valueOf(pos));
            if (removed != null) {
                removed._inmem = false;
            }
        }
    }

    /**
     * Returns the number of iterations assumed for a for-loop block, as
     * computed by {@code OptimizerUtils.getNumIterations} with
     * {@code DEFAULT_NUMITER} as second argument.
     *
     * @param stats current variable statistics (not used)
     * @param pb    the for-loop block
     * @return assumed iteration count
     */
    private static long getNumIterations(HashMap<String, VarStats> stats, ForProgramBlock pb) {
        long fallback = DEFAULT_NUMITER;
        return OptimizerUtils.getNumIterations(pb, fallback);
    }

    /**
     * Estimates the execution time of a single CP instruction.
     *
     * @param var1 the instruction to cost
     * @param var2 statistics for the instruction as produced by this class: three
     *             slots holding first input, second input and output (for
     *             multi-return builtins: input and the two outputs)
     * @param var3 instruction-specific attributes, or {@code null} if there are none
     * @return the estimate that is added to the cost of the enclosing block
     * @throws DMLRuntimeException if the instruction cannot be costed
     */
    protected abstract double getCPInstTimeEstimate(Instruction var1, VarStats[] var2, String[] var3) throws DMLRuntimeException;

    /**
     * Estimates the execution time of a single MR job instruction.
     *
     * @param var1 the MR job instruction to cost
     * @param var2 statistics indexed by position, covering index 0 up to the
     *             largest result index of the job; indices without known
     *             statistics hold the unknown placeholder
     * @param var3 attributes; this class always passes {@code null}
     * @return the estimate that is added to the cost of the enclosing block
     * @throws DMLRuntimeException if the job cannot be costed
     */
    protected abstract double getMRJobInstTimeEstimate(Instruction var1, VarStats[] var2, String[] var3) throws DMLRuntimeException;
}

