/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import org.apache.commons.lang.ArrayUtils;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateCell;
import org.apache.sysml.hops.codegen.template.TemplateMultiAgg;
import org.apache.sysml.hops.codegen.template.TemplateOuterProduct;
import org.apache.sysml.hops.codegen.template.TemplateRow;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.codegen.SpoofCellwise;
import org.apache.sysml.runtime.codegen.SpoofOuterProduct;
import org.apache.sysml.runtime.codegen.SpoofRowwise;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Static helper routines shared by the code generation templates.
 *
 * <p>Contents: shape predicates on {@link Hop}s and {@link CNode}s, template factories,
 * classification of cell/row/outer-product template types, and recursive queries over
 * CNode and Hop DAGs.
 */
public class TemplateUtils {
    /**
     * Template instances (row, cell, outer product) created with their no-argument constructors.
     * NOTE(review): this is a publicly exposed, mutable array; callers presumably must not modify it.
     */
    public static final TemplateBase[] TEMPLATES = new TemplateBase[]{new TemplateRow(), new TemplateCell(), new TemplateOuterProduct()};

    /**
     * Checks whether a hop is a matrix with exactly one dimension equal to 1 (a row or column
     * vector). This check does not consult {@code dimsKnown()}.
     *
     * @param hop the hop to inspect
     * @return true if the hop is a matrix of shape n x 1 or 1 x n with n != 1
     */
    public static boolean isVector(Hop hop) {
        if (hop.getDataType() != Expression.DataType.MATRIX) {
            return false;
        }
        boolean colShape = hop.getDim1() != 1L && hop.getDim2() == 1L;
        boolean rowShape = hop.getDim1() == 1L && hop.getDim2() != 1L;
        return colShape || rowShape;
    }

    /**
     * Checks whether a hop is a matrix with one column and a row count other than 1.
     *
     * @param hop the hop to inspect
     * @return true if the hop has column-vector shape
     */
    public static boolean isColVector(Hop hop) {
        return hop.getDataType() == Expression.DataType.MATRIX && hop.getDim1() != 1L && hop.getDim2() == 1L;
    }

    /**
     * Checks whether a CNode is a matrix with one column and a row count other than 1.
     *
     * @param hop the CNode to inspect
     * @return true if the node has column-vector shape
     */
    public static boolean isColVector(CNode hop) {
        return hop.getDataType() == Expression.DataType.MATRIX && hop.getNumRows() != 1L && hop.getNumCols() == 1L;
    }

    /**
     * Checks whether a CNode is a matrix with one row and a column count other than 1.
     *
     * @param hop the CNode to inspect
     * @return true if the node has row-vector shape
     */
    public static boolean isRowVector(CNode hop) {
        return hop.getDataType() == Expression.DataType.MATRIX && hop.getNumRows() == 1L && hop.getNumCols() != 1L;
    }

    /**
     * Checks whether a CNode is a matrix whose row and column counts are both different from 1.
     *
     * @param hop the CNode to inspect
     * @return true if the node is neither vector- nor 1x1-shaped
     */
    public static boolean isMatrix(CNode hop) {
        return hop.getDataType() == Expression.DataType.MATRIX && hop.getNumRows() != 1L && hop.getNumCols() != 1L;
    }

    /**
     * Wraps a node into a lookup unary node according to its shape.
     *
     * <ul>
     *   <li>column vector: {@code LOOKUP_R}</li>
     *   <li>row vector: {@code LOOKUP_C}</li>
     *   <li>a {@link CNodeData} whose hop is matrix-typed: {@code LOOKUP_RC}</li>
     * </ul>
     *
     * @param node the node to wrap
     * @param hop  the hop the node was created for; only its data type is consulted
     * @return the wrapping lookup node, or {@code node} itself if no rule applies
     */
    public static CNode wrapLookupIfNecessary(CNode node, Hop hop) {
        if (isColVector(node)) {
            return new CNodeUnary(node, CNodeUnary.UnaryType.LOOKUP_R);
        }
        if (isRowVector(node)) {
            return new CNodeUnary(node, CNodeUnary.UnaryType.LOOKUP_C);
        }
        if (node instanceof CNodeData && hop.getDataType().isMatrix()) {
            return new CNodeUnary(node, CNodeUnary.UnaryType.LOOKUP_RC);
        }
        return node;
    }

    /**
     * Checks whether a hop is a matrix whose dimensions are both different from 1.
     *
     * @param hop the hop to inspect
     * @return true if the hop is neither vector- nor 1x1-shaped
     */
    public static boolean isMatrix(Hop hop) {
        return hop.getDataType() == Expression.DataType.MATRIX && hop.getDim1() != 1L && hop.getDim2() != 1L;
    }

    /**
     * Checks whether a hop has known dimensions and is either a scalar or a vector.
     *
     * @param hop the hop to inspect
     * @return true if the dimensions are known and the hop is a scalar or a vector
     */
    public static boolean isVectorOrScalar(Hop hop) {
        return hop.dimsKnown() && (hop.getDataType() == Expression.DataType.SCALAR || isVector(hop));
    }

    /**
     * Checks whether a binary hop has two matrix inputs with known dimensions, where the left
     * input has more rows than the right input (presumably matrix-rowvector broadcasting).
     *
     * @param hop the hop to inspect
     * @return false for non-{@link BinaryOp} hops; otherwise the result of the check above
     */
    public static boolean isBinaryMatrixRowVector(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(1);
        return bothKnownMatrices(left, right) && left.getDim1() > right.getDim1();
    }

    /**
     * Checks whether a binary hop has two matrix inputs with known dimensions, where the left
     * input has more columns than the right input (presumably matrix-colvector broadcasting).
     *
     * @param hop the hop to inspect
     * @return false for non-{@link BinaryOp} hops; otherwise the result of the check above
     */
    public static boolean isBinaryMatrixColVector(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(1);
        return bothKnownMatrices(left, right) && left.getDim2() > right.getDim2();
    }

    /** Returns true if both hops have known dimensions and are matrix-typed. */
    private static boolean bothKnownMatrices(Hop left, Hop right) {
        return left.dimsKnown() && right.dimsKnown()
            && left.getDataType().isMatrix() && right.getDataType().isMatrix();
    }

    /**
     * Checks whether any direct input of a hop satisfies {@link #isMatrix(Hop)}.
     *
     * @param hop the hop whose inputs are inspected
     * @return true if at least one input is a (non-vector) matrix
     */
    public static boolean hasMatrixInput(Hop hop) {
        for (Hop input : hop.getInput()) {
            if (isMatrix(input)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks whether the operation of a hop has a counterpart in the CNode operation enums.
     *
     * <p>Unary, binary and ternary hops are looked up by operation name in the matching CNode
     * type enum. Parameterized builtins are looked up in {@link CNodeTernary.TernaryType}.
     *
     * @param h the hop to inspect
     * @return true if the operation name is contained in the corresponding enum; false for other hop kinds
     */
    public static boolean isOperationSupported(Hop h) {
        if (h instanceof UnaryOp) {
            return CNodeUnary.UnaryType.contains(((UnaryOp) h).getOp().name());
        }
        if (h instanceof BinaryOp) {
            return CNodeBinary.BinType.contains(((BinaryOp) h).getOp().name());
        }
        if (h instanceof TernaryOp) {
            return CNodeTernary.TernaryType.contains(((TernaryOp) h).getOp().name());
        }
        if (h instanceof ParameterizedBuiltinOp) {
            // parameterized builtins are code-generated via ternary CNode types
            return CNodeTernary.TernaryType.contains(((ParameterizedBuiltinOp) h).getOp().name());
        }
        return false;
    }

    /**
     * Creates a template of the given type with close type {@code OPEN_VALID}.
     *
     * @param type the template type
     * @return the new template, or null if the type is not handled by
     *         {@link #createTemplate(TemplateBase.TemplateType, TemplateBase.CloseType)}
     */
    public static TemplateBase createTemplate(TemplateBase.TemplateType type) {
        return createTemplate(type, TemplateBase.CloseType.OPEN_VALID);
    }

    /**
     * Creates a template of the given type and close type.
     *
     * @param type  the template type (CELL, ROW, MAGG or OUTER)
     * @param ctype the close type passed to the template constructor
     * @return the new template, or null for any other template type
     */
    public static TemplateBase createTemplate(TemplateBase.TemplateType type, TemplateBase.CloseType ctype) {
        TemplateBase tpl = null;
        switch (type) {
            case CELL:
                tpl = new TemplateCell(ctype);
                break;
            case ROW:
                tpl = new TemplateRow(ctype);
                break;
            case MAGG:
                tpl = new TemplateMultiAgg(ctype);
                break;
            case OUTER:
                tpl = new TemplateOuterProduct(ctype);
                break;
            default:
                break;
        }
        return tpl;
    }

    /**
     * Creates the set of templates that are compatible with the given type.
     *
     * <p>A CELL type yields both a cell and a row template. Each other handled type yields only
     * its own template.
     *
     * @param type  the template type (CELL, ROW, MAGG or OUTER)
     * @param ctype the close type passed to every template constructor
     * @return the compatible templates, or null for any other template type
     */
    public static TemplateBase[] createCompatibleTemplates(TemplateBase.TemplateType type, TemplateBase.CloseType ctype) {
        TemplateBase[] tpl = null;
        switch (type) {
            case CELL:
                tpl = new TemplateBase[]{new TemplateCell(ctype), new TemplateRow(ctype)};
                break;
            case ROW:
                tpl = new TemplateBase[]{new TemplateRow(ctype)};
                break;
            case MAGG:
                tpl = new TemplateBase[]{new TemplateMultiAgg(ctype)};
                break;
            case OUTER:
                tpl = new TemplateBase[]{new TemplateOuterProduct(ctype)};
                break;
            default:
                break;
        }
        return tpl;
    }

    /**
     * Derives the cellwise aggregation type from the output hop of a cell template.
     *
     * <ul>
     *   <li>{@link AggBinaryOp}: {@code FULL_AGG}</li>
     *   <li>{@link AggUnaryOp}: {@code ROW_AGG}, {@code COL_AGG} or {@code FULL_AGG}, by direction</li>
     *   <li>anything else: {@code NO_AGG}</li>
     * </ul>
     *
     * @param hop the output hop
     * @return the cell type
     */
    public static SpoofCellwise.CellType getCellType(Hop hop) {
        if (hop instanceof AggBinaryOp) {
            return SpoofCellwise.CellType.FULL_AGG;
        }
        if (hop instanceof AggUnaryOp) {
            switch (((AggUnaryOp) hop).getDirection()) {
                case Row:
                    return SpoofCellwise.CellType.ROW_AGG;
                case Col:
                    return SpoofCellwise.CellType.COL_AGG;
                case RowCol:
                    return SpoofCellwise.CellType.FULL_AGG;
                default:
                    break;
            }
        }
        return SpoofCellwise.CellType.NO_AGG;
    }

    /**
     * Classifies a row-wise template by comparing the output dimensions with those of the main
     * input X ({@code inputs[0]}) and the optional side input B1 ({@code inputs[1]}).
     *
     * <p>The checks are applied in order, and the first match wins. A missing X (null or no
     * inputs at all), an output of X's size, or X with unknown dimensions yields {@code NO_AGG}.
     *
     * @param output the output hop of the template
     * @param inputs X followed by the optional B1
     * @return the row type
     * @throws RuntimeException if none of the known row types matches
     */
    public static SpoofRowwise.RowType getRowType(Hop output, Hop ... inputs) {
        // an empty argument list is treated like a missing X instead of failing with an AIOOBE
        Hop X = inputs.length > 0 ? inputs[0] : null;
        Hop B1 = inputs.length > 1 ? inputs[1] : null;
        if (X == null || HopRewriteUtils.isEqualSize(output, X) || !X.dimsKnown()) {
            return SpoofRowwise.RowType.NO_AGG;
        }
        // a matrix multiply whose left input is t(X) must not be classified as NO_AGG_B1 or ROW_AGG
        boolean mmWithTransposedX = output instanceof AggBinaryOp
            && HopRewriteUtils.isTransposeOfItself(output.getInput().get(0), X);
        boolean outputMatchesXB1 = B1 != null
            && output.getDim1() == X.getDim1() && output.getDim2() == B1.getDim2();
        boolean columnRangeIndexing = output instanceof IndexingOp
            && HopRewriteUtils.isColumnRangeIndexing((IndexingOp) output);

        if ((outputMatchesXB1 || columnRangeIndexing) && !mmWithTransposedX) {
            return SpoofRowwise.RowType.NO_AGG_B1;
        }
        if (output.getDim1() == X.getDim1() && output.getDim2() == 1L && !mmWithTransposedX) {
            return SpoofRowwise.RowType.ROW_AGG;
        }
        if (output instanceof AggUnaryOp && ((AggUnaryOp) output).getDirection() == Hop.Direction.RowCol) {
            return SpoofRowwise.RowType.FULL_AGG;
        }
        if (output.getDim1() == X.getDim2() && output.getDim2() == 1L) {
            return SpoofRowwise.RowType.COL_AGG_T;
        }
        if (output.getDim1() == 1L && output.getDim2() == X.getDim2()) {
            return SpoofRowwise.RowType.COL_AGG;
        }
        if (B1 != null) {
            if (output.getDim1() == X.getDim2() && output.getDim2() == B1.getDim2()) {
                return SpoofRowwise.RowType.COL_AGG_B1_T;
            }
            if (output.getDim1() == B1.getDim2() && output.getDim2() == X.getDim2()) {
                return SpoofRowwise.RowType.COL_AGG_B1;
            }
            if (output.getDim1() == 1L && B1.getDim2() == output.getDim2()) {
                return SpoofRowwise.RowType.COL_AGG_B1R;
            }
        }
        if (X.getDim1() == output.getDim1() && X.getDim2() != output.getDim2()) {
            return SpoofRowwise.RowType.NO_AGG_CONST;
        }
        if (output.getDim1() == 1L && X.getDim2() != output.getDim2()) {
            return SpoofRowwise.RowType.COL_AGG_CONST;
        }
        throw new RuntimeException("Unknown row type for hop " + output.getHopID() + ".");
    }

    /**
     * Returns the aggregation operator of an aggregation hop.
     *
     * @param hop the hop to inspect
     * @return the op of an {@link AggUnaryOp}, {@code SUM} for an {@link AggBinaryOp}, otherwise null
     */
    public static Hop.AggOp getAggOp(Hop hop) {
        if (hop instanceof AggUnaryOp) {
            return ((AggUnaryOp) hop).getOp();
        }
        if (hop instanceof AggBinaryOp) {
            return Hop.AggOp.SUM;
        }
        return null;
    }

    /**
     * Classifies an outer-product template by the shape of its output hop.
     *
     * <ul>
     *   <li>scalar output: {@code AGG_OUTER_PRODUCT}</li>
     *   <li>U (or t(U)) as the left matrix multiply input, or a transposed output: {@code LEFT_OUTER_PRODUCT}</li>
     *   <li>V (or t(V)) as the right matrix multiply input: {@code RIGHT_OUTER_PRODUCT}</li>
     *   <li>a binary op over equally sized inputs: {@code CELLWISE_OUTER_PRODUCT}</li>
     * </ul>
     *
     * @param X   the main input (not consulted by the current classification)
     * @param U   the left factor
     * @param V   the right factor
     * @param out the output hop
     * @return the outer product type
     * @throws RuntimeException if none of the above applies
     */
    public static SpoofOuterProduct.OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out) {
        if (out.getDataType() == Expression.DataType.SCALAR) {
            return SpoofOuterProduct.OutProdType.AGG_OUTER_PRODUCT;
        }
        boolean isMM = out instanceof AggBinaryOp;
        if (isMM && isSelfOrTransposeOf(out.getInput().get(0), U) || HopRewriteUtils.isTransposeOperation(out)) {
            return SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT;
        }
        if (isMM && isSelfOrTransposeOf(out.getInput().get(1), V)) {
            return SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT;
        }
        if (out instanceof BinaryOp && HopRewriteUtils.isEqualSize(out.getInput().get(0), out.getInput().get(1))) {
            return SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT;
        }
        throw new RuntimeException("Undefined outer product type for hop " + out.getHopID());
    }

    /** Returns true if {@code in} is {@code probe} itself or a transpose operation directly over {@code probe}. */
    private static boolean isSelfOrTransposeOf(Hop in, Hop probe) {
        return in == probe
            || HopRewriteUtils.isTransposeOperation(in) && in.getInput().get(0) == probe;
    }

    /**
     * Checks whether a node is a lookup node.
     *
     * <p>{@code LOOKUP_C} and {@code LOOKUP_RC} unary nodes always count. {@code LOOKUP_R} unary
     * nodes and {@code LOOKUP_RC1} ternary nodes count only if {@code includeRC1} is set.
     *
     * @param node       the node to inspect
     * @param includeRC1 whether to also accept {@code LOOKUP_R} and {@code LOOKUP_RC1}
     * @return true if the node is an accepted lookup
     */
    public static boolean isLookup(CNode node, boolean includeRC1) {
        if (isUnary(node, CNodeUnary.UnaryType.LOOKUP_C, CNodeUnary.UnaryType.LOOKUP_RC)) {
            return true;
        }
        return includeRC1
            && (isUnary(node, CNodeUnary.UnaryType.LOOKUP_R)
                || isTernary(node, CNodeTernary.TernaryType.LOOKUP_RC1));
    }

    /**
     * Checks whether a node is a {@link CNodeUnary} of one of the given types.
     *
     * @param node  the node to inspect
     * @param types the accepted unary types
     * @return true if the node is unary and its type is among {@code types}
     */
    public static boolean isUnary(CNode node, CNodeUnary.UnaryType ... types) {
        return node instanceof CNodeUnary && ArrayUtils.contains(types, ((CNodeUnary) node).getType());
    }

    /**
     * Checks whether a node is a {@link CNodeBinary} of one of the given types.
     *
     * @param node  the node to inspect
     * @param types the accepted binary types
     * @return true if the node is binary and its type is among {@code types}
     */
    public static boolean isBinary(CNode node, CNodeBinary.BinType ... types) {
        return node instanceof CNodeBinary && ArrayUtils.contains(types, ((CNodeBinary) node).getType());
    }

    /**
     * Recursively checks that every node of a CNode subtree is one of the following: a binary
     * node of one of the given types, a data node, or a unary node that is a scalar lookup,
     * sparse-safe scalar, {@code POW2} or {@code MULT2}.
     *
     * <p>All children are evaluated, with no short-circuiting across siblings.
     *
     * @param node  the root of the subtree
     * @param types the accepted binary types
     * @return true if all nodes in the subtree are accepted
     */
    public static boolean rIsSparseSafeOnly(CNode node, CNodeBinary.BinType ... types) {
        if (!isBinary(node, types) && !(node instanceof CNodeData) && !isSparseSafeUnary(node)) {
            return false;
        }
        boolean ret = true;
        for (CNode c : node.getInput()) {
            ret &= rIsSparseSafeOnly(c, types);
        }
        return ret;
    }

    /** Returns true for unary nodes accepted by {@link #rIsSparseSafeOnly}. */
    private static boolean isSparseSafeUnary(CNode node) {
        if (!(node instanceof CNodeUnary)) {
            return false;
        }
        CNodeUnary.UnaryType t = ((CNodeUnary) node).getType();
        return t.isScalarLookup() || t.isSparseSafeScalar()
            || t == CNodeUnary.UnaryType.POW2 || t == CNodeUnary.UnaryType.MULT2;
    }

    /**
     * Recursively checks whether a CNode subtree contains a data node for the given hop id.
     * There is no visit tracking, so shared subtrees are visited once per path.
     *
     * @param node  the root of the subtree
     * @param hopID the hop id to search for
     * @return true if some {@link CNodeData} in the subtree has this hop id
     */
    public static boolean rContainsInput(CNode node, long hopID) {
        boolean ret = false;
        for (CNode c : node.getInput()) {
            ret |= rContainsInput(c, hopID);
        }
        if (node instanceof CNodeData) {
            ret |= ((CNodeData) node).getHopID() == hopID;
        }
        return ret;
    }

    /**
     * Checks whether a node is a {@link CNodeTernary} of one of the given types.
     *
     * @param node  the node to inspect
     * @param types the accepted ternary types
     * @return true if the node is ternary and its type is among {@code types}
     */
    public static boolean isTernary(CNode node, CNodeTernary.TernaryType ... types) {
        return node instanceof CNodeTernary && ArrayUtils.contains(types, ((CNodeTernary) node).getType());
    }

    /**
     * Creates a data node for a hop and marks it as a literal.
     *
     * <p>The literal flag is set only for a {@link LiteralOp}, and then only if
     * {@code compileLiterals} is set or the literal's string value is an integer number.
     *
     * @param hop             the hop to wrap
     * @param compileLiterals whether all literals are compiled into the generated code
     * @return the new data node
     */
    public static CNodeData createCNodeData(Hop hop, boolean compileLiterals) {
        CNodeData cdata = new CNodeData(hop);
        boolean literal = hop instanceof LiteralOp
            && (compileLiterals || UtilFunctions.isIntegerNumber(((LiteralOp) hop).getStringValue()));
        cdata.setLiteral(literal);
        return cdata;
    }

    /**
     * For a transpose hop, maps both the hop and its input to the CNode of the input.
     *
     * <p>The input's node is created via {@link #createCNodeData} if it is not yet in {@code tmp}.
     *
     * @param cdataOrig       node returned unchanged for non-transpose hops
     * @param hop             the hop to inspect
     * @param tmp             hop id to CNode map; updated for transpose hops
     * @param compileLiterals forwarded to {@link #createCNodeData}
     * @return the input's node for a transpose hop, otherwise {@code cdataOrig}
     */
    public static CNode skipTranspose(CNode cdataOrig, Hop hop, HashMap<Long, CNode> tmp, boolean compileLiterals) {
        if (!HopRewriteUtils.isTransposeOperation(hop)) {
            return cdataOrig;
        }
        Hop input = hop.getInput().get(0);
        CNode cdata = tmp.get(input.getHopID());
        if (cdata == null) {
            cdata = createCNodeData(input, compileLiterals);
            tmp.put(input.getHopID(), cdata);
        }
        tmp.put(hop.getHopID(), cdata);
        return cdata;
    }

    /**
     * Checks whether a hop has a transpose parent that itself has an outer-product-like
     * matrix multiply parent.
     *
     * @param hop the hop whose parents are inspected
     * @return true if such a parent chain exists
     */
    public static boolean hasTransposeParentUnderOuterProduct(Hop hop) {
        for (Hop parent : hop.getParent()) {
            if (HopRewriteUtils.isTransposeOperation(parent)) {
                for (Hop grandParent : parent.getParent()) {
                    if (HopRewriteUtils.isOuterProductLikeMM(grandParent)) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    /**
     * Checks whether a template consists of a single operation over data or lookup inputs.
     *
     * <p>The output node must be one of:
     * <ul>
     *   <li>a unary node other than EXP, LOG or ROW_COUNTNNZS</li>
     *   <li>a binary node other than VECT_OUTERMULT_ADD</li>
     *   <li>an IFELSE ternary node</li>
     * </ul>
     * In addition, all of its inputs must satisfy {@link #hasOnlyDataNodeOrLookupInputs}.
     *
     * @param tpl the template to inspect
     * @return true if the template is such a single operation
     */
    public static boolean hasSingleOperation(CNodeTpl tpl) {
        CNode output = tpl.getOutput();
        boolean singleUnary = output instanceof CNodeUnary
            && !isUnary(output, CNodeUnary.UnaryType.EXP, CNodeUnary.UnaryType.LOG, CNodeUnary.UnaryType.ROW_COUNTNNZS);
        boolean singleBinary = output instanceof CNodeBinary
            && !isBinary(output, CNodeBinary.BinType.VECT_OUTERMULT_ADD);
        boolean singleIfElse = output instanceof CNodeTernary
            && ((CNodeTernary) output).getType() == CNodeTernary.TernaryType.IFELSE;
        return (singleUnary || singleBinary || singleIfElse) && hasOnlyDataNodeOrLookupInputs(output);
    }

    /**
     * Checks whether a template performs no operation, meaning its output is a data node or a
     * lookup (with {@code includeRC1 = true}).
     *
     * @param tpl the template to inspect
     * @return true if the template output is a data node or a lookup
     */
    public static boolean hasNoOperation(CNodeTpl tpl) {
        CNode output = tpl.getOutput();
        return output instanceof CNodeData || isLookup(output, true);
    }

    /**
     * Checks whether all direct inputs of a node are data nodes or {@code LOOKUP0},
     * {@code LOOKUP_R} or {@code LOOKUP_RC} unary nodes.
     *
     * @param node the node whose inputs are inspected
     * @return true if every input qualifies (vacuously true without inputs)
     */
    public static boolean hasOnlyDataNodeOrLookupInputs(CNode node) {
        boolean ret = true;
        for (CNode c : node.getInput()) {
            ret &= c instanceof CNodeData
                || isUnary(c, CNodeUnary.UnaryType.LOOKUP0, CNodeUnary.UnaryType.LOOKUP_R, CNodeUnary.UnaryType.LOOKUP_RC);
        }
        return ret;
    }

    /**
     * Estimates the minimum number of vector intermediates needed by a CNode DAG.
     *
     * <p>For a unary operator pipeline this uses {@link #getMaxVectorIntermediates}. Otherwise it
     * uses {@link #countVectorIntermediates}. Visit status is reset before and after each traversal.
     *
     * @param node the root of the DAG
     * @return the estimated number of vector intermediates
     */
    public static int determineMinVectorIntermediates(CNode node) {
        node.resetVisitStatus();
        boolean unaryPipe = isUnaryOperatorPipeline(node);
        node.resetVisitStatus();
        int count = unaryPipe ? getMaxVectorIntermediates(node) : countVectorIntermediates(node);
        node.resetVisitStatus();
        return count;
    }

    /**
     * Checks whether a CNode DAG is a unary operator pipeline.
     *
     * <p>Reaching an already visited binary vector-primitive node (a second reference to it)
     * makes the result false. Requires reset visit status on entry; marks nodes as visited.
     *
     * @param node the root of the DAG
     * @return true if no already visited vector-primitive binary node is re-reached
     */
    public static boolean isUnaryOperatorPipeline(CNode node) {
        if (node.isVisited()) {
            return !(node instanceof CNodeBinary && ((CNodeBinary) node).getType().isVectorPrimitive());
        }
        boolean ret = true;
        for (CNode input : node.getInput()) {
            ret &= isUnaryOperatorPipeline(input);
        }
        node.setVisited();
        return ret;
    }

    /**
     * Computes the maximum per-node vector intermediate requirement over a CNode DAG.
     *
     * <p>Per-node weights are:
     * <ul>
     *   <li>vector-primitive ternary: 1</li>
     *   <li>binary: 3 (vector-vector), 2 (vector-scalar) or 1 (vector-matrix)</li>
     *   <li>vector-scalar unary: 2</li>
     * </ul>
     * Already visited nodes contribute 0. Marks nodes as visited.
     *
     * @param node the root of the DAG
     * @return the maximum weight found
     */
    public static int getMaxVectorIntermediates(CNode node) {
        if (node.isVisited()) {
            return 0;
        }
        int max = 0;
        for (CNode input : node.getInput()) {
            max = Math.max(max, getMaxVectorIntermediates(input));
        }
        if (node instanceof CNodeTernary && ((CNodeTernary) node).getType().isVectorPrimitive()) {
            max = Math.max(max, 1);
        }
        if (node instanceof CNodeBinary) {
            max = Math.max(max, binaryVectorIntermediates(((CNodeBinary) node).getType()));
        }
        if (node instanceof CNodeUnary && ((CNodeUnary) node).getType().isVectorScalarPrimitive()) {
            max = Math.max(max, 2);
        }
        node.setVisited();
        return max;
    }

    /** Vector intermediate weight of a binary type, as used by {@link #getMaxVectorIntermediates}. */
    private static int binaryVectorIntermediates(CNodeBinary.BinType type) {
        if (type.isVectorVectorPrimitive()) {
            return 3;
        }
        if (type.isVectorScalarPrimitive()) {
            return 2;
        }
        if (type.isVectorMatrixPrimitive()) {
            return 1;
        }
        return 0;
    }

    /**
     * Counts the vector-producing nodes of a CNode DAG, each distinct node at most once.
     *
     * <p>A node counts if it is one of:
     * <ul>
     *   <li>a vector-primitive binary whose type name does not end in {@code _ADD}</li>
     *   <li>a vector-scalar unary</li>
     *   <li>a vector-primitive ternary</li>
     * </ul>
     * Marks nodes as visited before recursing.
     *
     * @param node the root of the DAG
     * @return the number of counted nodes
     */
    public static int countVectorIntermediates(CNode node) {
        if (node.isVisited()) {
            return 0;
        }
        node.setVisited();
        int ret = 0;
        for (CNode c : node.getInput()) {
            ret += countVectorIntermediates(c);
        }
        // binary *_ADD primitives are not counted as separate intermediates
        int cntBin = node instanceof CNodeBinary && ((CNodeBinary) node).getType().isVectorPrimitive()
            && !((CNodeBinary) node).getType().name().endsWith("_ADD") ? 1 : 0;
        int cntUn = node instanceof CNodeUnary && ((CNodeUnary) node).getType().isVectorScalarPrimitive() ? 1 : 0;
        int cntTn = node instanceof CNodeTernary && ((CNodeTernary) node).getType().isVectorPrimitive() ? 1 : 0;
        return ret + cntBin + cntUn + cntTn;
    }

    /**
     * Checks whether a template type is one of the given types.
     *
     * @param type       the type to check
     * @param validTypes the accepted types
     * @return true if {@code type} is contained in {@code validTypes}
     */
    public static boolean isType(TemplateBase.TemplateType type, TemplateBase.TemplateType ... validTypes) {
        return ArrayUtils.contains(validTypes, type);
    }

    /**
     * Checks whether two hops share the same row-template matrix input.
     *
     * <p>Returns true right away if {@code input2} has no ROW entry in the memo table.
     *
     * @param input1 the first hop
     * @param input2 the second hop
     * @param memo   the plan memo table
     * @return true if both hops resolve to the same matrix input id via {@link #getRowTemplateMatrixInput}
     */
    public static boolean hasCommonRowTemplateMatrixInput(Hop input1, Hop input2, CPlanMemoTable memo) {
        if (!memo.contains(input2.getHopID(), TemplateBase.TemplateType.ROW)) {
            return true;
        }
        long main1 = getRowTemplateMatrixInput(input1, memo);
        long main2 = getRowTemplateMatrixInput(input2, memo);
        return main1 == main2;
    }

    /**
     * Finds the hop id of the first matrix input of a hop's best ROW plan.
     *
     * <p>Plan references into ROW entries are followed recursively. Non-referenced inputs count
     * only if they satisfy {@link #isMatrix(Hop)}.
     *
     * <p>NOTE(review): assumes {@code memo.getBest} yields a non-null entry for {@code current};
     * confirm callers only pass hops with a ROW memo entry.
     *
     * @param current the hop to start from
     * @param memo    the plan memo table
     * @return the matrix input's hop id, or -1 if none is found
     */
    public static long getRowTemplateMatrixInput(Hop current, CPlanMemoTable memo) {
        CPlanMemoTable.MemoTableEntry me = memo.getBest(current.getHopID(), TemplateBase.TemplateType.ROW);
        long ret = -1L;
        for (int i = 0; ret < 0L && i < current.getInput().size(); i++) {
            Hop input = current.getInput().get(i);
            if (me.isPlanRef(i)) {
                if (memo.contains(input.getHopID(), TemplateBase.TemplateType.ROW)) {
                    ret = getRowTemplateMatrixInput(input, memo);
                }
            } else if (isMatrix(input)) {
                ret = input.getHopID();
            }
        }
        return ret;
    }

    /**
     * Checks whether a CNode DAG contains a binary node of the given type. Visit status is reset
     * before and after the traversal.
     *
     * @param node the root of the DAG
     * @param type the binary type to search for
     * @return true if such a node exists
     */
    public static boolean containsBinary(CNode node, CNodeBinary.BinType type) {
        node.resetVisitStatus();
        boolean ret = rContainsBinary(node, type);
        node.resetVisitStatus();
        return ret;
    }

    /**
     * Recursive worker for {@link #containsBinary}. Already visited nodes return false; nodes are
     * marked as visited after their inputs were processed.
     *
     * @param node the current node
     * @param type the binary type to search for
     * @return true if the subtree contains an unvisited binary node of this type
     */
    public static boolean rContainsBinary(CNode node, CNodeBinary.BinType type) {
        if (node.isVisited()) {
            return false;
        }
        boolean ret = false;
        for (CNode input : node.getInput()) {
            ret |= rContainsBinary(input, type);
        }
        node.setVisited();
        return ret | isBinary(node, type);
    }

    /**
     * Checks whether a hop DAG contains an outer-product-like matrix multiply. Visit status is
     * reset before and after the traversal.
     *
     * @param hop the root of the DAG
     * @return true if such a hop exists
     */
    public static boolean containsOuterProduct(Hop hop) {
        hop.resetVisitStatus();
        boolean ret = rContainsOuterProduct(hop);
        hop.resetVisitStatus();
        return ret;
    }

    /**
     * Checks whether a hop DAG contains an outer-product-like matrix multiply that has
     * {@code probe} (or its transpose, per {@code HopRewriteUtils.isTransposeOfItself}) as a
     * direct input. Visit status is reset before and after the traversal.
     *
     * @param hop   the root of the DAG
     * @param probe the input to look for
     * @return true if such a hop exists
     */
    public static boolean containsOuterProduct(Hop hop, Hop probe) {
        hop.resetVisitStatus();
        boolean ret = rContainsOuterProduct(hop, probe);
        hop.resetVisitStatus();
        return ret;
    }

    /** Recursive worker for {@link #containsOuterProduct(Hop)}; stops descending once a match is found. */
    private static boolean rContainsOuterProduct(Hop current) {
        if (current.isVisited()) {
            return false;
        }
        boolean ret = HopRewriteUtils.isOuterProductLikeMM(current);
        for (int i = 0; i < current.getInput().size() && !ret; i++) {
            ret |= rContainsOuterProduct(current.getInput().get(i));
        }
        current.setVisited();
        return ret;
    }

    /** Recursive worker for {@link #containsOuterProduct(Hop, Hop)}; stops descending once a match is found. */
    private static boolean rContainsOuterProduct(Hop current, Hop probe) {
        if (current.isVisited()) {
            return false;
        }
        boolean ret = HopRewriteUtils.isOuterProductLikeMM(current)
            && checkContainment(current.getInput(), probe, true);
        for (int i = 0; i < current.getInput().size() && !ret; i++) {
            ret |= rContainsOuterProduct(current.getInput().get(i), probe);
        }
        current.setVisited();
        return ret;
    }

    /**
     * Checks whether {@code probe} occurs among {@code inputs}. With {@code inclTranspose} the
     * check uses {@code HopRewriteUtils.isTransposeOfItself}; otherwise it uses list containment.
     */
    private static boolean checkContainment(ArrayList<Hop> inputs, Hop probe, boolean inclTranspose) {
        if (!inclTranspose) {
            return inputs.contains(probe);
        }
        for (Hop hop : inputs) {
            if (HopRewriteUtils.isTransposeOfItself(hop, probe)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Recursively swaps {@code LOOKUP_C} and {@code LOOKUP_R} unary nodes in a CNode subtree.
     *
     * <p>NOTE(review): there is no visit tracking. If a lookup node is reachable via more than
     * one path, it is flipped once per path and may end up unchanged; confirm this cannot happen
     * for the DAGs passed in.
     *
     * @param current the root of the subtree
     */
    public static void rFlipVectorLookups(CNode current) {
        if (isUnary(current, CNodeUnary.UnaryType.LOOKUP_C)) {
            ((CNodeUnary) current).setType(CNodeUnary.UnaryType.LOOKUP_R);
        } else if (isUnary(current, CNodeUnary.UnaryType.LOOKUP_R)) {
            ((CNodeUnary) current).setType(CNodeUnary.UnaryType.LOOKUP_C);
        }
        for (CNode input : current.getInput()) {
            rFlipVectorLookups(input);
        }
    }
}

