/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.lops.Compression;
import org.apache.sysml.lops.MMTSJ;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;

/**
 * Statement-block rewrite that marks persistent-read hops as requiring compression.
 *
 * <p>Behavior is driven by the {@code sysml.compressed.linalg} configuration value, parsed as a
 * {@link Compression.CompressConfig}:
 * <ul>
 *   <li>{@code TRUE}: every persistent read with more than one row and more than one column is
 *       marked.</li>
 *   <li>{@code AUTO}: the same shape condition, and additionally:
 *     <ul>
 *       <li>Spark execution mode is active;</li>
 *       <li>the dimensions are known;</li>
 *       <li>the estimated size exceeds the Spark data memory budget;</li>
 *       <li>the matrix is not ultra-sparse;</li>
 *       <li>a whole-program probe shows the matrix is read inside a loop, is not updated
 *           under an if/else, and is only consumed by supported operations.</li>
 *     </ul>
 *   </li>
 * </ul>
 */
public class RewriteCompressedReblock extends StatementBlockRewriteRule {

    /** Prefix of synthetic names used to track intermediate (unnamed) hops during the probe. */
    private static final String TMP_PREFIX = "__cmtx";

    /**
     * Sparsity below which a matrix is considered ultra-sparse and is never auto-compressed.
     * NOTE(review): this literal was likely inlined from a MatrixBlock constant by the
     * compiler/decompiler — confirm and reference that constant directly if so.
     */
    private static final double ULTRA_SPARSITY_THRESHOLD = 4.0E-5;

    /**
     * Injects compression directives into the hop DAGs of a last-level statement block.
     *
     * @param sb    the statement block to rewrite
     * @param state rewrite status (unused by this rule)
     * @return a singleton list containing {@code sb} (the block itself is never replaced)
     * @throws HopsException if hop inspection fails
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException {
        if (!HopRewriteUtils.isLastLevelStatementBlock(sb) || sb.getHops() == null) {
            return Arrays.asList(sb);
        }
        DMLConfig conf = ConfigurationManager.getDMLConfig();
        // Locale.ROOT keeps the enum-name lookup independent of the JVM default locale.
        Compression.CompressConfig compress = Compression.CompressConfig.valueOf(
            conf.getTextValue("sysml.compressed.linalg").toUpperCase(Locale.ROOT));
        if (compress.isEnabled()) {
            Hop.resetVisitStatus(sb.getHops());
            for (Hop h : sb.getHops()) {
                injectCompressionDirective(h, compress, sb.getDMLProg());
            }
            Hop.resetVisitStatus(sb.getHops());
        }
        return Arrays.asList(sb);
    }

    /**
     * This rule performs no rewrite over lists of statement blocks.
     *
     * @return {@code sbs}, unchanged
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) throws HopsException {
        return sbs;
    }

    /**
     * Walks the DAG bottom-up (inputs first) and sets the requires-compression flag on qualifying
     * hops. Hops already visited or already flagged are skipped.
     */
    private static void injectCompressionDirective(Hop hop, Compression.CompressConfig compress, DMLProgram prog) throws HopsException {
        if (hop.isVisited() || hop.requiresCompression()) {
            return;
        }
        for (Hop hi : hop.getInput()) {
            injectCompressionDirective(hi, compress, prog);
        }
        boolean forced = compress == Compression.CompressConfig.TRUE
            && satisfiesCompressionCondition(hop);
        boolean auto = compress == Compression.CompressConfig.AUTO
            && satisfiesAutoCompressionCondition(hop, prog);
        if (forced || auto) {
            hop.setRequiresCompression(true);
        }
        hop.setVisited();
    }

    /** Returns true for a persistent read with more than one row and more than one column. */
    private static boolean satisfiesCompressionCondition(Hop hop) {
        return HopRewriteUtils.isData(hop, Hop.DataOpTypes.PERSISTENTREAD)
            && hop.getDim1() > 1L && hop.getDim2() > 1L;
    }

    /**
     * Decides whether a hop qualifies for compression in AUTO mode.
     *
     * <p>It requires the basic shape condition and Spark execution mode. It also requires
     * dimensions known ({@code dimsKnown(true)}) and an estimated partitioned size above the
     * Spark data memory budget. The matrix must have sparsity of at least
     * {@link #ULTRA_SPARSITY_THRESHOLD}. Finally, a program-wide probe must find that the matrix
     * is used in a loop, never updated inside an if/else, and only consumed by supported
     * operations.
     */
    private static boolean satisfiesAutoCompressionCondition(Hop hop, DMLProgram prog) throws HopsException {
        if (!satisfiesCompressionCondition(hop) || !OptimizerUtils.isSparkExecutionMode()) {
            return false;
        }
        double matrixPSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(
            hop.getDim1(), hop.getDim2(), hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz());
        double cacheSize = SparkExecutionContext.getDataMemoryBudget(true, true);
        boolean outOfCore = matrixPSize > cacheSize;
        double sparsity = OptimizerUtils.getSparsity(hop.getDim1(), hop.getDim2(), hop.getNnz());
        boolean ultraSparse = sparsity < ULTRA_SPARSITY_THRESHOLD;
        boolean dimsKnown = hop.dimsKnown(true);

        if (dimsKnown && outOfCore && !ultraSparse) {
            ProbeStatus status = new ProbeStatus(hop.getHopID(), prog);
            for (StatementBlock sb : prog.getStatementBlocks()) {
                rAnalyzeProgram(sb, status);
            }
            boolean ret = status.foundStart && status.usedInLoop
                && !status.condUpdate && !status.nonApplicable;
            if (LOG.isDebugEnabled()) {
                // Fixed: previously reported foundStart under the usedInLoop label.
                LOG.debug("Auto compression: " + ret + " (dimsKnown=" + dimsKnown
                    + ", outOfCore=" + outOfCore + ", !ultraSparse=" + !ultraSparse
                    + ", foundStart=" + status.foundStart + ", usedInLoop=" + status.usedInLoop
                    + ", !condUpdate=" + !status.condUpdate
                    + ", !nonApplicable=" + !status.nonApplicable + ")");
            }
            return ret;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Auto compression: false (dimsKnown=" + dimsKnown + ", outOfCore="
                + outOfCore + ", !ultraSparse=" + !ultraSparse + ")");
        }
        return false;
    }

    /**
     * Recursively probes a statement block, updating {@code status}.
     *
     * <ul>
     *   <li>A while/for block whose read set contains a tracked name sets {@code usedInLoop}.</li>
     *   <li>An if block whose update set contains a tracked name sets {@code condUpdate}.</li>
     *   <li>A last-level block has its hop DAGs analyzed. Synthetic (TMP_PREFIX) names are then
     *       dropped, because they are only meaningful within one DAG.</li>
     * </ul>
     */
    private static void rAnalyzeProgram(StatementBlock sb, ProbeStatus status) throws HopsException {
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                rAnalyzeProgram(csb, status);
            }
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            for (StatementBlock csb : wstmt.getBody()) {
                rAnalyzeProgram(csb, status);
            }
            if (wsb.variablesRead().containsAnyName(status.compMtx)) {
                status.usedInLoop = true;
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            for (StatementBlock csb : istmt.getIfBody()) {
                rAnalyzeProgram(csb, status);
            }
            for (StatementBlock csb : istmt.getElseBody()) {
                rAnalyzeProgram(csb, status);
            }
            if (isb.variablesUpdated().containsAnyName(status.compMtx)) {
                status.condUpdate = true;
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                rAnalyzeProgram(csb, status);
            }
            if (fsb.variablesRead().containsAnyName(status.compMtx)) {
                status.usedInLoop = true;
            }
        } else if (sb.getHops() != null) {
            ArrayList<Hop> roots = sb.getHops();
            Hop.resetVisitStatus(roots);
            for (Hop root : roots) {
                rAnalyzeHopDag(root, status);
            }
            status.compMtx.removeIf(n -> n.startsWith(TMP_PREFIX));
            Hop.resetVisitStatus(roots);
        }
    }

    /**
     * Recursively analyzes a hop DAG bottom-up, propagating the set of tracked ("compressed")
     * names.
     *
     * <p>Names propagate through function calls, transient writes and transient reads, and
     * through operations that keep the data compressed. Any other consumer of a tracked input
     * that is not a supported operation marks the probe as non-applicable.
     */
    private static void rAnalyzeHopDag(Hop current, ProbeStatus status) throws HopsException {
        if (current.isVisited()) {
            return;
        }
        for (Hop input : current.getInput()) {
            rAnalyzeHopDag(input, status);
        }
        if (current.getHopID() == status.startHopID) {
            status.compMtx.add(getTmpName(current));
            status.foundStart = true;
        }
        if (current instanceof FunctionOp && hasCompressedInput(current, status)) {
            analyzeFunctionCall((FunctionOp) current, status);
        } else if (HopRewriteUtils.isData(current, Hop.DataOpTypes.TRANSIENTWRITE)
                && status.compMtx.contains(getTmpName(current.getInput().get(0)))) {
            // A tracked intermediate is bound to a named variable.
            status.compMtx.add(current.getName());
        } else if (HopRewriteUtils.isData(current, Hop.DataOpTypes.TRANSIENTREAD)
                && status.compMtx.contains(current.getName())) {
            // A tracked variable is read back into this DAG.
            status.compMtx.add(getTmpName(current));
        } else if (hasCompressedInput(current, status)) {
            boolean compUCOut = consumesCompressedWithUncompressedOutput(current);
            boolean compCOut = HopRewriteUtils.isBinaryMatrixScalarOperation(current)
                || HopRewriteUtils.isBinary(current, Hop.OpOp2.CBIND);
            boolean metaOp = HopRewriteUtils.isUnary(current, Hop.OpOp1.NROW, Hop.OpOp1.NCOL);
            status.nonApplicable |= !compUCOut && !compCOut && !metaOp;
            if (compCOut) {
                // Output of these operations stays tracked as compressed.
                status.compMtx.add(getTmpName(current));
            }
        }
        current.setVisited();
    }

    /**
     * Probes the body of a called function (once per function key).
     *
     * <p>The parameters bound to tracked arguments are seeded as tracked. The probe flags are
     * merged back into {@code status}, and the call's outputs bound to tracked function outputs
     * become tracked.
     */
    private static void analyzeFunctionCall(FunctionOp fop, ProbeStatus status) throws HopsException {
        String fkey = fop.getFunctionKey();
        if (!status.procFn.add(fkey)) {
            return; // function already probed
        }
        FunctionStatementBlock fsb = status.prog.getFunctionStatementBlock(fkey);
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
        ProbeStatus fnStatus = new ProbeStatus(status);
        for (int i = 0; i < fop.getInput().size(); ++i) {
            if (status.compMtx.contains(getTmpName(fop.getInput().get(i)))) {
                fnStatus.compMtx.add(fstmt.getInputParams().get(i).getName());
            }
        }
        rAnalyzeProgram(fsb, fnStatus);
        status.foundStart |= fnStatus.foundStart;
        status.usedInLoop |= fnStatus.usedInLoop;
        status.condUpdate |= fnStatus.condUpdate;
        status.nonApplicable |= fnStatus.nonApplicable;
        String[] outputs = fop.getOutputVariableNames();
        for (int i = 0; i < outputs.length; ++i) {
            if (fnStatus.compMtx.contains(fstmt.getOutputParams().get(i).getName())) {
                status.compMtx.add(outputs[i]);
            }
        }
    }

    /**
     * Returns true for supported operations that consume a tracked input but whose result is not
     * tracked as compressed.
     *
     * <p>These are:
     * <ul>
     *   <li>an AggBinaryOp whose transpose-self check is LEFT and whose column count fits in one
     *       column block;</li>
     *   <li>an AggBinaryOp with a single-row or single-column output;</li>
     *   <li>a transpose whose only parent is such a vector-output AggBinaryOp;</li>
     *   <li>the SUM, SUM_SQ, MIN and MAX aggregate-unary operations.</li>
     * </ul>
     */
    private static boolean consumesCompressedWithUncompressedOutput(Hop current) {
        if (current instanceof AggBinaryOp) {
            if (current.getDim2() <= current.getColsInBlock()
                    && ((AggBinaryOp) current).checkTransposeSelf() == MMTSJ.MMTSJType.LEFT) {
                return true;
            }
            if (current.getDim1() == 1L || current.getDim2() == 1L) {
                return true;
            }
        }
        if (HopRewriteUtils.isTransposeOperation(current) && current.getParent().size() == 1) {
            Hop parent = current.getParent().get(0);
            if (parent instanceof AggBinaryOp && (parent.getDim1() == 1L || parent.getDim2() == 1L)) {
                return true;
            }
        }
        return HopRewriteUtils.isAggUnaryOp(current,
            Hop.AggOp.SUM, Hop.AggOp.SUM_SQ, Hop.AggOp.MIN, Hop.AggOp.MAX);
    }

    /** Synthetic tracking name for an (unnamed) intermediate hop. */
    private static String getTmpName(Hop hop) {
        return TMP_PREFIX + hop.getHopID();
    }

    /** Returns true if any direct input of {@code hop} is currently tracked. */
    private static boolean hasCompressedInput(Hop hop, ProbeStatus status) {
        if (status.compMtx.isEmpty()) {
            return false;
        }
        for (Hop input : hop.getInput()) {
            if (status.compMtx.contains(getTmpName(input))) {
                return true;
            }
        }
        return false;
    }

    /** Mutable state of one program-wide probe for a single candidate hop. */
    private static class ProbeStatus {
        /** ID of the candidate hop whose compression is being evaluated. */
        private final long startHopID;
        /** Program used to resolve called functions. */
        private final DMLProgram prog;
        /** The candidate hop was encountered during the probe. */
        private boolean foundStart = false;
        /** A tracked name is read by a while/for block. */
        private boolean usedInLoop = false;
        /** A tracked name is updated inside an if/else block. */
        private boolean condUpdate = false;
        /** A tracked input is consumed by an unsupported operation. */
        private boolean nonApplicable = false;
        /** Function keys already probed (guards against repeated or recursive analysis). */
        private final HashSet<String> procFn = new HashSet<>();
        /** Currently tracked variable names and synthetic intermediate names. */
        private final HashSet<String> compMtx = new HashSet<>();

        public ProbeStatus(long hopID, DMLProgram p) {
            this.startHopID = hopID;
            this.prog = p;
        }

        /**
         * Copies flags and processed functions for probing a function body.
         * {@code compMtx} intentionally starts empty: the caller seeds it with the function's
         * parameter names.
         */
        public ProbeStatus(ProbeStatus status) {
            this.startHopID = status.startHopID;
            this.prog = status.prog;
            this.foundStart = status.foundStart;
            this.usedInLoop = status.usedInLoop;
            this.condUpdate = status.condUpdate;
            this.nonApplicable = status.nonApplicable;
            this.procFn.addAll(status.procFn);
        }
    }
}

