/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.ipa;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ipa.FunctionCallGraph;
import org.apache.sysml.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysml.hops.ipa.IPAPass;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.StatementBlock;

public class IPAPassInlineFunctions
extends IPAPass {
    @Override
    public boolean isApplicable() {
        return true;
    }

    @Override
    public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) throws HopsException {
        for (String fkey : fgraph.getReachableFunctions()) {
            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
            FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
            if (fstmt.getBody().size() != 1 || !HopRewriteUtils.isLastLevelStatementBlock(fstmt.getBody().get(0)) || IPAPassInlineFunctions.containsFunctionOp(fstmt.getBody().get(0).getHops()) || fgraph.getFunctionCalls(fkey).size() != 1 && IPAPassInlineFunctions.countOperators(fstmt.getBody().get(0).getHops()) > 10) continue;
            if (LOG.isDebugEnabled()) {
                LOG.debug((Object)("IPA: Inline function '" + fkey + "'"));
            }
            ArrayList<Hop> hops = fstmt.getBody().get(0).getHops();
            List<FunctionOp> fcalls = fgraph.getFunctionCalls(fkey);
            List<StatementBlock> fcallsSB = fgraph.getFunctionCallsSB(fkey);
            for (int i = 0; i < fcalls.size(); ++i) {
                int j;
                FunctionOp op = fcalls.get(i);
                if (op.getInput().size() != fstmt.getInputParams().size() || op.getOutputVariableNames().length != fstmt.getOutputParams().size()) continue;
                ArrayList<Hop> hops2 = Recompiler.deepCopyHopsDag(hops);
                HashMap<String, Hop> inMap = new HashMap<String, Hop>();
                for (int j2 = 0; j2 < op.getInput().size(); ++j2) {
                    inMap.put(fstmt.getInputParams().get(j2).getName(), op.getInput().get(j2));
                }
                IPAPassInlineFunctions.replaceTransientReads(hops2, inMap);
                HashMap<String, String> outMap = new HashMap<String, String>();
                String[] opOutputs = op.getOutputVariableNames();
                for (j = 0; j < opOutputs.length; ++j) {
                    outMap.put(fstmt.getOutputParams().get(j).getName(), opOutputs[j]);
                }
                for (j = 0; j < hops2.size(); ++j) {
                    Hop out = hops2.get(j);
                    if (!HopRewriteUtils.isData(out, Hop.DataOpTypes.TRANSIENTWRITE)) continue;
                    out.setName((String)outMap.get(out.getName()));
                }
                fcallsSB.get(i).getHops().remove(op);
                fcallsSB.get(i).getHops().addAll(hops2);
            }
            fgraph.removeFunctionCalls(fkey);
        }
    }

    private static boolean containsFunctionOp(ArrayList<Hop> hops) {
        if (hops == null || hops.isEmpty()) {
            return false;
        }
        Hop.resetVisitStatus(hops);
        boolean ret = HopRewriteUtils.containsOp(hops, FunctionOp.class);
        Hop.resetVisitStatus(hops);
        return ret;
    }

    private static int countOperators(ArrayList<Hop> hops) {
        if (hops == null || hops.isEmpty()) {
            return 0;
        }
        Hop.resetVisitStatus(hops);
        int count = 0;
        for (Hop hop : hops) {
            count += IPAPassInlineFunctions.rCountOperators(hop);
        }
        Hop.resetVisitStatus(hops);
        return count;
    }

    private static int rCountOperators(Hop current) {
        if (current.isVisited()) {
            return 0;
        }
        int count = !(current instanceof DataOp) && !(current instanceof LiteralOp) ? 1 : 0;
        for (Hop c : current.getInput()) {
            count += IPAPassInlineFunctions.rCountOperators(c);
        }
        current.setVisited();
        return count;
    }

    private static void replaceTransientReads(ArrayList<Hop> hops, HashMap<String, Hop> inMap) {
        Hop.resetVisitStatus(hops);
        for (Hop hop : hops) {
            IPAPassInlineFunctions.rReplaceTransientReads(hop, inMap);
        }
        Hop.resetVisitStatus(hops);
    }

    private static void rReplaceTransientReads(Hop current, HashMap<String, Hop> inMap) {
        if (current.isVisited()) {
            return;
        }
        for (int i = 0; i < current.getInput().size(); ++i) {
            Hop c = current.getInput().get(i);
            IPAPassInlineFunctions.rReplaceTransientReads(c, inMap);
            if (!HopRewriteUtils.isData(c, Hop.DataOpTypes.TRANSIENTREAD)) continue;
            HopRewriteUtils.replaceChildReference(current, c, inMap.get(c.getName()));
        }
        current.setVisited();
    }
}

