/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.globalopt.gdfgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.globalopt.Summary;
import org.apache.sysml.hops.globalopt.gdfgraph.GDFCrossBlockNode;
import org.apache.sysml.hops.globalopt.gdfgraph.GDFGraph;
import org.apache.sysml.hops.globalopt.gdfgraph.GDFLoopNode;
import org.apache.sysml.hops.globalopt.gdfgraph.GDFNode;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.utils.Explain;

/**
 * Builds the global data flow (GDF) graph of a runtime {@link Program}.
 *
 * <p>While walking the program blocks, the builder maintains a map ("roots") from
 * variable name to the GDF node that currently produces that variable. Each block
 * reads the roots visible before it and updates them with the nodes it produces.
 * Control-flow blocks (while, if, for) process their bodies on copies of this map
 * and merge the results back through {@link GDFCrossBlockNode}s.
 */
public class GraphBuilder {
    /**
     * When {@code true}, variables reported by {@code variablesUpdated()} of a loop
     * block that have no node in the loop body's roots are skipped while collecting
     * loop outputs. When {@code false}, such variables raise a {@link HopsException}.
     */
    private static final boolean IGNORE_UNBOUND_UPDATED_VARS = true;

    /**
     * Constructs the global data flow graph for the top-level program blocks of
     * {@code prog} and records the elapsed construction time in {@code summary}.
     *
     * @param prog    program whose program blocks are processed in order
     * @param summary receives the construction time via {@code setTimeGDFGraph}
     * @return a graph over {@code prog} whose roots are the final per-variable nodes,
     *         excluding {@link GDFCrossBlockNode}s
     * @throws DMLRuntimeException if a function program block is encountered or a
     *                             loop input variable has no producing node
     * @throws HopsException       if a hop DAG root cannot be constructed
     */
    public static GDFGraph constructGlobalDataFlowGraph(Program prog, Summary summary) throws DMLRuntimeException, HopsException {
        Timing time = new Timing(true);
        HashMap<String, GDFNode> roots = new HashMap<String, GDFNode>();
        for (ProgramBlock programBlock : prog.getProgramBlocks()) {
            constructGDFGraph(programBlock, roots);
        }
        // Cross-block nodes are excluded from the graph's root list.
        ArrayList<GDFNode> ret = new ArrayList<GDFNode>();
        for (GDFNode root : roots.values()) {
            if (!(root instanceof GDFCrossBlockNode)) {
                ret.add(root);
            }
        }
        GDFGraph graph = new GDFGraph(prog, ret);
        summary.setTimeGDFGraph(time.stop());
        return graph;
    }

    /**
     * Adds the data flow of a single program block to the graph. It dispatches on the
     * block type and updates {@code roots} in place.
     *
     * @param pb    program block to process
     * @param roots map from variable name to its current producing node; updated in place
     * @throws DMLRuntimeException if {@code pb} is a {@link FunctionProgramBlock}
     *                             (unsupported) or a loop input has no producing node
     * @throws HopsException       if a hop DAG root cannot be constructed
     */
    private static void constructGDFGraph(ProgramBlock pb, HashMap<String, GDFNode> roots) throws DMLRuntimeException, HopsException {
        if (pb instanceof FunctionProgramBlock) {
            throw new DMLRuntimeException("FunctionProgramBlocks not implemented yet.");
        }
        if (pb instanceof WhileProgramBlock) {
            constructWhileBlockGraph((WhileProgramBlock) pb, roots);
        } else if (pb instanceof IfProgramBlock) {
            constructIfBlockGraph((IfProgramBlock) pb, roots);
        } else if (pb instanceof ForProgramBlock) {
            constructForBlockGraph((ForProgramBlock) pb, roots);
        } else {
            constructBasicBlockGraph(pb, roots);
        }
    }

    /**
     * Builds the loop node of a while block and registers cross-block nodes for its
     * live-out outputs in {@code roots}.
     */
    private static void constructWhileBlockGraph(WhileProgramBlock wpb, HashMap<String, GDFNode> roots) throws DMLRuntimeException, HopsException {
        WhileStatementBlock wsb = (WhileStatementBlock) wpb.getStatementBlock();
        // The predicate is built against the roots visible before the loop.
        GDFNode pred = constructGDFGraph(wsb.getPredicateHops(), wpb, new HashMap<Long, GDFNode>(), roots);
        HashMap<String, GDFNode> inputs = constructLoopInputNodes(wpb, wsb, roots);
        // The body starts from the loop inputs only. It works on a copy so that
        // 'inputs' itself is left unchanged for the loop node.
        HashMap<String, GDFNode> lroots = copyRoots(inputs);
        for (ProgramBlock pbc : wpb.getChildBlocks()) {
            constructGDFGraph(pbc, lroots);
        }
        HashMap<String, GDFNode> outputs = constructLoopOutputNodes(wsb, lroots);
        // NOTE(review): lnode is only linked from roots through live-out outputs.
        GDFLoopNode lnode = new GDFLoopNode(wpb, pred, inputs, outputs);
        constructLoopOutputCrossBlockNodes(wsb, lnode, outputs, roots, wpb);
    }

    /**
     * Processes both branches of an if block on separate copies of {@code roots} and
     * merges the resulting per-variable nodes back into {@code roots}.
     */
    private static void constructIfBlockGraph(IfProgramBlock ipb, HashMap<String, GDFNode> roots) throws DMLRuntimeException, HopsException {
        IfStatementBlock isb = (IfStatementBlock) ipb.getStatementBlock();
        Hop pred = isb.getPredicateHops();
        if (pred != null) {
            // The predicate node is registered in roots under the predicate hop's name.
            roots.put(pred.getName(), constructGDFGraph(pred, ipb, new HashMap<Long, GDFNode>(), roots));
        }
        HashMap<String, GDFNode> ifRoots = copyRoots(roots);
        HashMap<String, GDFNode> elseRoots = copyRoots(roots);
        for (ProgramBlock pbc : ipb.getChildBlocksIfBody()) {
            constructGDFGraph(pbc, ifRoots);
        }
        // Without an else body, elseRoots stays a copy of the incoming roots.
        if (ipb.getChildBlocksElseBody() != null) {
            for (ProgramBlock pbc : ipb.getChildBlocksElseBody()) {
                constructGDFGraph(pbc, elseRoots);
            }
        }
        reconcileMergeIfProgramBlockOutputs(ifRoots, elseRoots, roots, ipb);
    }

    /**
     * Builds the loop node of a for block, with a predicate made from its from/to/increment
     * hops, and registers cross-block nodes for its live-out outputs in {@code roots}.
     */
    private static void constructForBlockGraph(ForProgramBlock fpb, HashMap<String, GDFNode> roots) throws DMLRuntimeException, HopsException {
        ForStatementBlock fsb = (ForStatementBlock) fpb.getStatementBlock();
        GDFNode pred = constructForPredicateNode(fpb, fsb, roots);
        HashMap<String, GDFNode> inputs = constructLoopInputNodes(fpb, fsb, roots);
        // The body starts from the loop inputs only, on a private copy.
        HashMap<String, GDFNode> lroots = copyRoots(inputs);
        for (ProgramBlock pbc : fpb.getChildBlocks()) {
            constructGDFGraph(pbc, lroots);
        }
        HashMap<String, GDFNode> outputs = constructLoopOutputNodes(fsb, lroots);
        GDFLoopNode lnode = new GDFLoopNode(fpb, pred, inputs, outputs);
        constructLoopOutputCrossBlockNodes(fsb, lnode, outputs, roots, fpb);
    }

    /**
     * Converts each hop DAG of a generic (non-control-flow) block into GDF nodes and
     * registers each DAG root in {@code roots} under its hop name. Transient writes are
     * wrapped in a {@link GDFCrossBlockNode}.
     */
    private static void constructBasicBlockGraph(ProgramBlock pb, HashMap<String, GDFNode> roots) throws DMLRuntimeException, HopsException {
        StatementBlock sb = pb.getStatementBlock();
        ArrayList<Hop> hops = sb.get_hops();
        if (hops == null) {
            return;
        }
        // A single memo for all DAGs of this block, so a hop reached from several
        // roots maps to a single GDF node.
        HashMap<Long, GDFNode> lmemo = new HashMap<Long, GDFNode>();
        for (Hop hop : hops) {
            GDFNode root = constructGDFGraph(hop, pb, lmemo, roots);
            // Defensive: constructGDFGraph(Hop, ...) never returns null as written.
            if (root == null) {
                throw new HopsException("GDFGraphBuilder: failed to constuct dag root for: " + Explain.explain(hop));
            }
            if (hop instanceof DataOp && ((DataOp) hop).getDataOpType() == Hop.DataOpTypes.TRANSIENTWRITE) {
                root = new GDFCrossBlockNode(hop, pb, root, hop.getName());
            }
            roots.put(hop.getName(), root);
        }
    }

    /**
     * Recursively converts a hop DAG into GDF nodes, memoized by hop ID.
     *
     * <p>Inputs are converted first. A transient read gets one extra input: the node
     * currently registered in {@code roots} under the hop's name. That node may be
     * {@code null} if no node is registered for the name.
     *
     * @param hop   root of the hop (sub-)DAG to convert
     * @param pb    program block the hop belongs to
     * @param lmemo hop ID to already-created node; updated in place
     * @param roots current variable-to-node map, read for transient reads
     * @return the GDF node for {@code hop}
     */
    private static GDFNode constructGDFGraph(Hop hop, ProgramBlock pb, HashMap<Long, GDFNode> lmemo, HashMap<String, GDFNode> roots) {
        if (lmemo.containsKey(hop.getHopID())) {
            return lmemo.get(hop.getHopID());
        }
        ArrayList<GDFNode> inputs = new ArrayList<GDFNode>();
        for (Hop c : hop.getInput()) {
            inputs.add(constructGDFGraph(c, pb, lmemo, roots));
        }
        if (hop instanceof DataOp && ((DataOp) hop).getDataOpType() == Hop.DataOpTypes.TRANSIENTREAD) {
            inputs.add(roots.get(hop.getName()));
        }
        GDFNode gnode = new GDFNode(hop, pb, inputs);
        lmemo.put(hop.getHopID(), gnode);
        return gnode;
    }

    /**
     * Builds the predicate node of a for loop: a GDF node without a hop whose three
     * inputs are, in order, the from, to and increment sub-graphs. Each input is
     * {@code null} when the corresponding hop is absent. The three share one memo.
     */
    private static GDFNode constructForPredicateNode(ForProgramBlock fpb, ForStatementBlock fsb, HashMap<String, GDFNode> roots) {
        HashMap<Long, GDFNode> memo = new HashMap<Long, GDFNode>();
        GDFNode from = fsb.getFromHops() != null ? constructGDFGraph(fsb.getFromHops(), fpb, memo, roots) : null;
        GDFNode to = fsb.getToHops() != null ? constructGDFGraph(fsb.getToHops(), fpb, memo, roots) : null;
        GDFNode incr = fsb.getIncrementHops() != null ? constructGDFGraph(fsb.getIncrementHops(), fpb, memo, roots) : null;
        ArrayList<GDFNode> inputs = new ArrayList<GDFNode>();
        inputs.add(from);
        inputs.add(to);
        inputs.add(incr);
        return new GDFNode(null, fpb, inputs);
    }

    /**
     * Collects the loop inputs. These are the variables read by {@code sb} that are
     * also live-in, mapped to their current node in {@code roots}.
     *
     * @param pb    loop program block (currently unused)
     * @param sb    loop statement block
     * @param roots current variable-to-node map
     * @return a new map from input variable name to its producing node
     * @throws DMLRuntimeException if a live-in read variable has no node in {@code roots}
     */
    private static HashMap<String, GDFNode> constructLoopInputNodes(ProgramBlock pb, StatementBlock sb, HashMap<String, GDFNode> roots) throws DMLRuntimeException {
        HashMap<String, GDFNode> ret = new HashMap<String, GDFNode>();
        Set<String> invars = sb.variablesRead().getVariableNames();
        for (String var : invars) {
            if (!sb.liveIn().containsVariable(var)) {
                continue;
            }
            GDFNode node = roots.get(var);
            if (node == null) {
                throw new DMLRuntimeException("GDFGraphBuilder: Non-existing input node for variable: " + var);
            }
            ret.put(var, node);
        }
        return ret;
    }

    /**
     * Collects the loop outputs. These are the variables updated by {@code sb}, mapped to
     * their node in the loop body's {@code roots}. Variables without a node are skipped
     * or rejected depending on {@link #IGNORE_UNBOUND_UPDATED_VARS}.
     *
     * @param sb    loop statement block
     * @param roots variable-to-node map after processing the loop body
     * @return a new map from output variable name to its producing node
     * @throws HopsException if an updated variable has no node and unbound variables
     *                       are not ignored
     */
    private static HashMap<String, GDFNode> constructLoopOutputNodes(StatementBlock sb, HashMap<String, GDFNode> roots) throws HopsException {
        HashMap<String, GDFNode> ret = new HashMap<String, GDFNode>();
        Set<String> outvars = sb.variablesUpdated().getVariableNames();
        for (String var : outvars) {
            GDFNode node = roots.get(var);
            if (node == null) {
                if (!IGNORE_UNBOUND_UPDATED_VARS) {
                    throw new HopsException("GDFGraphBuilder: Non-existing output node for variable: " + var);
                }
                continue;
            }
            ret.put(var, node);
        }
        return ret;
    }

    /**
     * Merges the roots of the if and else branches back into {@code roots}.
     *
     * <p>For every variable in {@code ifRoots}: if the else branch maps it to a different
     * node (including no node at all), a {@link GDFCrossBlockNode} joining both nodes is
     * registered. Otherwise the shared node is kept. Variables that exist only in
     * {@code elseRoots} are copied over unchanged.
     */
    private static void reconcileMergeIfProgramBlockOutputs(HashMap<String, GDFNode> ifRoots, HashMap<String, GDFNode> elseRoots, HashMap<String, GDFNode> roots, IfProgramBlock pb) {
        for (Map.Entry<String, GDFNode> e : ifRoots.entrySet()) {
            String var = e.getKey();
            GDFNode ifNode = e.getValue();
            GDFNode elseNode = elseRoots.get(var);
            // Reference comparison: only an identical node counts as unchanged.
            GDFNode merged = (ifNode != elseNode) ? new GDFCrossBlockNode(null, pb, ifNode, elseNode, var) : ifNode;
            roots.put(var, merged);
        }
        for (Map.Entry<String, GDFNode> e : elseRoots.entrySet()) {
            if (!ifRoots.containsKey(e.getKey())) {
                roots.put(e.getKey(), e.getValue());
            }
        }
    }

    /**
     * Registers a {@link GDFCrossBlockNode} in {@code roots} for each loop output that is
     * live-out of {@code sb}. If the variable already has a node in {@code roots}, the
     * cross-block node joins that node with the loop node. Otherwise it refers to the
     * loop node only.
     */
    private static void constructLoopOutputCrossBlockNodes(StatementBlock sb, GDFLoopNode loop, HashMap<String, GDFNode> loutputs, HashMap<String, GDFNode> roots, ProgramBlock pb) {
        for (Map.Entry<String, GDFNode> e : loutputs.entrySet()) {
            String var = e.getKey();
            if (!sb.liveOut().containsVariable(var)) {
                continue;
            }
            GDFCrossBlockNode node;
            if (roots.containsKey(var)) {
                node = new GDFCrossBlockNode(null, pb, roots.get(var), loop, var);
            } else {
                node = new GDFCrossBlockNode(null, pb, loop, var);
            }
            roots.put(var, node);
        }
    }

    /**
     * Returns a shallow copy of a roots map. The map is new, but the keys and node
     * references are shared with {@code roots}.
     */
    @SuppressWarnings("unchecked")
    private static HashMap<String, GDFNode> copyRoots(HashMap<String, GDFNode> roots) {
        // HashMap.clone() is declared to return Object. The cast is safe because the
        // clone has the same key and value types as the source map.
        return (HashMap<String, GDFNode>) roots.clone();
    }
}

