/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.matrix.data.Pair;

/**
 * Statement-block rewrite that splits a hop DAG at data-dependent operators whose output
 * dimensions are unknown at compile time.
 *
 * <p>For each candidate operator, a new statement block is created that computes the operator
 * and stores its result in a transient variable. All consumers of the operator in the original
 * block (other than hops that are themselves below a candidate) are rewired to a transient read
 * of that variable. The new block is rewritten recursively and placed in front of the original
 * block, and both blocks are flagged via {@code setSplitDag(true)}.
 *
 * <p>The rewrite is not applied when running in {@code SINGLE_NODE} mode or when the block is
 * not a last-level statement block.
 */
public class RewriteSplitDagDataDependentOperators
extends StatementBlockRewriteRule {
    /** Name prefix for the transient variables introduced by this rewrite. */
    private static final String _varnamePrefix = "_sbcvar";

    /** Sequence used to make generated transient variable names unique. */
    private static final IDSequence _seq = new IDSequence();

    /**
     * Splits {@code sb} into a block computing all data-dependent candidates (recursively
     * rewritten) followed by {@code sb} itself, which reads the candidates' results via
     * transient reads.
     *
     * @param sb statement block to rewrite
     * @param state program rewrite status, passed through to the recursive invocation
     * @return the resulting statement blocks; just {@code sb} if nothing was split
     * @throws HopsException if splitting the hop DAG fails
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException {
        // rewrite not applicable: return the block unchanged
        if (DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.SINGLE_NODE || !HopRewriteUtils.isLastLevelStatementBlock(sb)) {
            return new ArrayList<StatementBlock>(Arrays.asList(sb));
        }
        ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
        ArrayList<Hop> cand = new ArrayList<Hop>();
        this.collectDataDependentOperators(sb.getHops(), cand);
        Hop.resetVisitStatus(sb.getHops());
        if (cand.isEmpty()) {
            ret.add(sb);
            return ret;
        }

        // Hops reachable below a candidate end up in the new block as well (it is built from
        // everything reachable from the candidates), so they keep direct references to candidates.
        HashSet<Hop> candChilds = new HashSet<Hop>();
        this.collectCandidateChildOperators(cand, candChilds);
        try {
            StatementBlock sb1 = new StatementBlock();
            sb1.setDMLProg(sb.getDMLProg());
            sb1.setParseInfo(sb);
            sb1.setLiveIn(new VariableSet());
            sb1.setLiveOut(new VariableSet());
            ArrayList<Hop> sb1hops = new ArrayList<Hop>();
            for (Hop c : cand) {
                // An existing transient write of c is moved into sb1 (reusing its variable name)
                // if HopRewriteUtils.rHasSimpleReadChain allows it; otherwise a fresh one is created.
                Hop twrite = getFirstTransientWriteParent(c);
                boolean moveTWrite = twrite != null && HopRewriteUtils.rHasSimpleReadChain(c, twrite.getName());
                String varname = moveTWrite ? twrite.getName() : _varnamePrefix + _seq.getNextID();
                DataOp tread = createTransientRead(varname, c);

                // copy the parent list: replaceChildReference mutates c's parents
                for (Hop parent : new ArrayList<Hop>(c.getParent())) {
                    if (candChilds.contains(parent)) {
                        continue;
                    }
                    if (moveTWrite && parent == twrite) {
                        // the existing write moves into sb1, so drop it from sb's roots
                        sb.getHops().remove(parent);
                    } else {
                        HopRewriteUtils.replaceChildReference(parent, c, tread);
                    }
                }
                sb1hops.add(moveTWrite ? twrite : createTransientWrite(varname, c));
                addLiveVariable(varname, c, sb1.liveOut(), sb.liveIn());
            }
            this.handleReplicatedOperators(sb1hops, sb.getHops(), sb1.liveOut(), sb.liveIn());
            sb1.setHops(Recompiler.deepCopyHopsDag(sb1hops));
            sb1.updateRecompilationFlag();
            sb1.setSplitDag(true);
            ret.addAll(this.rewriteStatementBlock(sb1, state));
            ret.add(sb);
            sb.setSplitDag(true);
        }
        catch (Exception ex) {
            throw new HopsException("Failed to split hops dag for data dependent operators with unknown size.", ex);
        }
        LOG.debug("Applied splitDagDataDependentOperators (lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ").");
        return ret;
    }

    /**
     * Collects all data-dependent split candidates reachable from {@code roots} into
     * {@code cand}. Resets the visit status of {@code roots} before traversal; a {@code null}
     * root list is ignored.
     */
    private void collectDataDependentOperators(ArrayList<Hop> roots, ArrayList<Hop> cand) {
        if (roots == null) {
            return;
        }
        Hop.resetVisitStatus(roots);
        for (Hop root : roots) {
            this.rCollectDataDependentOperators(root, cand);
        }
    }

    /**
     * Recursively identifies split candidates and configures output flags on them.
     *
     * <p>Candidates are: removeEmpty (RMEMPTY) hops with unknown size, ctable (CTABLE) hops with
     * fewer than 4 inputs and unknown size, and non-literal, non-data inputs 2 and 3 of sort
     * (SORT) hops. Recursion stops below a hop that produced a candidate.
     */
    private void rCollectDataDependentOperators(Hop hop, ArrayList<Hop> cand) {
        if (hop.isVisited()) {
            return;
        }
        // no split needed if the output size is known or the hop only feeds writes
        boolean noSplitRequired = hop.dimsKnown() || HopRewriteUtils.hasOnlyWriteParents(hop, true, true);
        boolean investigateChilds = true;

        // case 1: removeEmpty with unknown size, unless its single consumer is a ternary op for
        // which isMatrixIgnoreZeroRewriteApplicable() holds
        if (hop instanceof ParameterizedBuiltinOp
                && ((ParameterizedBuiltinOp) hop).getOp() == Hop.ParamBuiltinOp.RMEMPTY
                && !noSplitRequired
                && !(hop.getParent().size() == 1
                    && hop.getParent().get(0) instanceof TernaryOp
                    && ((TernaryOp) hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable())) {
            ParameterizedBuiltinOp pbhop = (ParameterizedBuiltinOp) hop;
            cand.add(pbhop);
            investigateChilds = false;
            boolean noEmptyBlocks = true;
            boolean onlyPMM = true;
            boolean diagInput = pbhop.isTargetDiagInput();
            for (Hop p : hop.getParent()) {
                // consumer is a matrix multiply with this hop as its left input
                boolean leftMatMultInput = p instanceof AggBinaryOp && hop == p.getInput().get(0);
                noEmptyBlocks &= leftMatMultInput || HopRewriteUtils.isUnary(p, Hop.OpOp1.NROW);
                onlyPMM &= leftMatMultInput;
            }
            // empty blocks are only produced if some consumer is neither a left matmult input nor nrow
            pbhop.setOutputEmptyBlocks(!noEmptyBlocks);
            if (onlyPMM && diagInput) {
                if (ConfigurationManager.isDynamicRecompilation()) {
                    pbhop.setOutputPermutationMatrix(true);
                }
                for (Hop p : hop.getParent()) {
                    ((AggBinaryOp) p).setHasLeftPMInput(true);
                }
            }
        }

        // case 2: ctable with fewer than 4 inputs (presumably no explicit output dims) and unknown size
        if (hop instanceof TernaryOp && ((TernaryOp) hop).getOp() == Hop.OpOp3.CTABLE && hop.getInput().size() < 4 && !noSplitRequired) {
            cand.add(hop);
            investigateChilds = false;
            boolean onlyPMM = true;
            for (Hop p : hop.getParent()) {
                onlyPMM &= p instanceof AggBinaryOp && hop == p.getInput().get(0);
            }
            if (onlyPMM && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0))) {
                hop.setOutputEmptyBlocks(false);
            }
        }

        // case 3: sort inputs 2 and 3 (presumably 'decreasing' / 'index.return') computed in this DAG
        // NOTE(review): this disables recursion into *all* SORT inputs, including 0 and 1 - confirm intended
        if (hop instanceof ReorgOp && ((ReorgOp) hop).getOp() == Hop.ReOrgOp.SORT) {
            for (int i = 2; i <= 3; ++i) {
                Hop c = hop.getInput().get(i);
                if (c instanceof LiteralOp || c instanceof DataOp) {
                    continue;
                }
                cand.add(c);
                c.setVisited();
                investigateChilds = false;
            }
        }

        if (investigateChilds && hop.getInput() != null) {
            for (Hop c : hop.getInput()) {
                this.rCollectDataDependentOperators(c, cand);
            }
        }
        hop.setVisited();
    }

    /**
     * Returns the first parent of {@code hop} that is a transient-write {@link DataOp}, or
     * {@code null} if there is none.
     */
    private static Hop getFirstTransientWriteParent(Hop hop) {
        for (Hop p : hop.getParent()) {
            if (p instanceof DataOp && ((DataOp) p).getDataOpType() == Hop.DataOpTypes.TRANSIENTWRITE) {
                return p;
            }
        }
        return null;
    }

    /**
     * Creates a transient read of {@code varname} with the data/value type, dimensions, nnz,
     * update type and block sizes of {@code c}. The read is marked visited and inherits
     * {@code c}'s line numbers.
     */
    private static DataOp createTransientRead(String varname, Hop c) {
        DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), Hop.DataOpTypes.TRANSIENTREAD, null,
            c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
        tread.setVisited();
        HopRewriteUtils.copyLineNumbers(c, tread);
        return tread;
    }

    /**
     * Creates a transient write of {@code c} into {@code varname}, with output parameters taken
     * from {@code c}. The write is marked visited and inherits {@code c}'s line numbers.
     */
    private static DataOp createTransientWrite(String varname, Hop c) {
        DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, Hop.DataOpTypes.TRANSIENTWRITE, null);
        twrite.setVisited();
        twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock());
        HopRewriteUtils.copyLineNumbers(c, twrite);
        return twrite;
    }

    /**
     * Registers {@code varname} (with the metadata of {@code c}) as live-out of the producing
     * block and live-in of the consuming block, using a separate identifier copy for each.
     */
    private static void addLiveVariable(String varname, Hop c, VariableSet producerOut, VariableSet consumerIn) {
        DataIdentifier diVar = new DataIdentifier(varname);
        diVar.setDimensions(c.getDim1(), c.getDim2());
        diVar.setBlockDimensions(c.getRowsInBlock(), c.getColsInBlock());
        diVar.setDataType(c.getDataType());
        diVar.setValueType(c.getValueType());
        producerOut.addVariable(varname, new DataIdentifier(diVar));
        consumerIn.addVariable(varname, new DataIdentifier(diVar));
    }

    /**
     * Handles hops shared between the new block's roots ({@code rootsSB1}) and the original
     * block's roots ({@code rootsSB2}).
     *
     * <p>Each direct reference from an SB2 hop to a hop reachable from SB1 (excluding
     * non-persistent data ops and literals) is replaced by a transient read. A matching
     * transient write is appended to {@code rootsSB1}, and the variable is registered in
     * {@code sb1out}/{@code sb2in}. Presumably this avoids evaluating the shared hops in both
     * blocks, since SB1's hops are deep-copied afterwards.
     */
    private void handleReplicatedOperators(ArrayList<Hop> rootsSB1, ArrayList<Hop> rootsSB2, VariableSet sb1out, VariableSet sb2in) {
        HashSet<Hop> probeSet = new HashSet<Hop>();
        Hop.resetVisitStatus(rootsSB1);
        for (Hop h : rootsSB1) {
            this.rAddHopsToProbeSet(h, probeSet);
        }
        HashSet<Pair<Hop, Hop>> candSet = new HashSet<Pair<Hop, Hop>>();
        Hop.resetVisitStatus(rootsSB2);
        for (Hop hop : rootsSB2) {
            this.rProbeAndAddHopsToCandidateSet(hop, probeSet, candSet);
        }
        for (Pair<Hop, Hop> pair : candSet) {
            String varname = _varnamePrefix + _seq.getNextID();
            Hop hop = pair.getKey();
            Hop c = pair.getValue();
            DataOp tread = createTransientRead(varname, c);
            DataOp twrite = createTransientWrite(varname, c);
            // swap c for the transient read at the same input position of hop
            int pos = HopRewriteUtils.getChildReferencePos(hop, c);
            HopRewriteUtils.removeChildReferenceByPos(hop, c, pos);
            HopRewriteUtils.addChildReference(hop, tread, pos);
            addLiveVariable(varname, c, sb1out, sb2in);
            rootsSB1.add(twrite);
        }
    }

    /**
     * Adds every hop reachable from {@code hop} to {@code probeSet}, except literals and
     * data ops that are not persistent reads/writes.
     */
    private void rAddHopsToProbeSet(Hop hop, HashSet<Hop> probeSet) {
        if (hop.isVisited()) {
            return;
        }
        boolean nonPersistentData = hop instanceof DataOp && !((DataOp) hop).isPersistentReadWrite();
        if (!nonPersistentData && !(hop instanceof LiteralOp)) {
            probeSet.add(hop);
        }
        if (hop.getInput() != null) {
            for (Hop c : hop.getInput()) {
                this.rAddHopsToProbeSet(c, probeSet);
            }
        }
        hop.setVisited();
    }

    /**
     * Records (parent, child) pairs where {@code child} is in {@code probeSet}. Recursion
     * continues only into children that are not in the probe set.
     */
    private void rProbeAndAddHopsToCandidateSet(Hop hop, HashSet<Hop> probeSet, HashSet<Pair<Hop, Hop>> candSet) {
        if (hop.isVisited()) {
            return;
        }
        if (hop.getInput() != null) {
            for (Hop c : hop.getInput()) {
                if (probeSet.contains(c)) {
                    candSet.add(new Pair<Hop, Hop>(hop, c));
                } else {
                    this.rProbeAndAddHopsToCandidateSet(c, probeSet, candSet);
                }
            }
        }
        hop.setVisited();
    }

    /**
     * Collects into {@code candChilds} all hops reachable below any candidate. Visit status of
     * the candidates is reset before and after the traversal.
     */
    private void collectCandidateChildOperators(ArrayList<Hop> cand, HashSet<Hop> candChilds) {
        if (cand == null) {
            return;
        }
        Hop.resetVisitStatus(cand);
        for (Hop root : cand) {
            this.rCollectCandidateChildOperators(root, cand, candChilds, false);
        }
        Hop.resetVisitStatus(cand);
    }

    /**
     * Recursive helper: adds {@code hop} to {@code candChilds} if {@code collect} is set, and
     * turns collection on for all hops below a candidate.
     */
    private void rCollectCandidateChildOperators(Hop hop, ArrayList<Hop> cand, HashSet<Hop> candChilds, boolean collect) {
        if (hop.isVisited()) {
            return;
        }
        if (collect) {
            candChilds.add(hop);
        }
        boolean passedFlag = collect || cand.contains(hop);
        if (hop.getInput() != null) {
            for (Hop c : hop.getInput()) {
                this.rCollectCandidateChildOperators(c, cand, candChilds, passedFlag);
            }
        }
        hop.setVisited();
    }

    /** This rule operates on individual statement blocks only; the list is returned unchanged. */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) throws HopsException {
        return sbs;
    }
}

