/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.cplan;

import java.util.Locale;

import org.apache.commons.lang.StringUtils;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.util.UtilFunctions;

public class CNodeUnary
extends CNode {
    private UnaryType _type;

    public CNodeUnary(CNode in1, UnaryType type) {
        this._inputs.add(in1);
        this._type = type;
        this.setOutputDims();
    }

    public UnaryType getType() {
        return this._type;
    }

    public void setType(UnaryType type) {
        this._type = type;
    }

    @Override
    public String codegen(boolean sparse) {
        if (this.isGenerated()) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        sb.append(((CNode)this._inputs.get(0)).codegen(sparse));
        boolean lsparse = sparse && this._inputs.get(0) instanceof CNodeData && !((CNode)this._inputs.get(0)).getVarname().startsWith("b") && !((CNode)this._inputs.get(0)).isLiteral();
        String var = this.createVarname();
        String tmp = this._type.getTemplate(lsparse);
        tmp = tmp.replace("%TMP%", var);
        String varj = ((CNode)this._inputs.get(0)).getVarname();
        tmp = tmp.replace("%IN1v%", varj + "vals");
        tmp = tmp.replace("%IN1i%", varj + "ix");
        tmp = tmp.replace("%IN1%", varj.startsWith("b") && !this._type.isScalarLookup() && TemplateUtils.isMatrix((CNode)this._inputs.get(0)) ? varj + ".ddat" : varj);
        String spos = this._inputs.get(0) instanceof CNodeData && ((CNode)this._inputs.get(0)).getDataType().isMatrix() ? (!varj.startsWith("b") ? varj + "i" : (TemplateUtils.isMatrix((CNode)this._inputs.get(0)) ? "rowIndex*%LEN%" : "0")) : "0";
        tmp = tmp.replace("%POS1%", spos);
        tmp = tmp.replace("%POS2%", spos);
        if (((CNode)this._inputs.get(0)).getDataType().isMatrix()) {
            tmp = tmp.replace("%LEN%", ((CNode)this._inputs.get(0)).getVectorLength());
        }
        sb.append(tmp);
        this._generated = true;
        return sb.toString();
    }

    public String toString() {
        switch (this._type) {
            case ROW_SUMS: {
                return "u(R+)";
            }
            case ROW_MINS: {
                return "u(Rmin)";
            }
            case ROW_MAXS: {
                return "u(Rmax)";
            }
            case VECT_EXP: 
            case VECT_POW2: 
            case VECT_MULT2: 
            case VECT_SQRT: 
            case VECT_LOG: 
            case VECT_ABS: 
            case VECT_ROUND: 
            case VECT_CEIL: 
            case VECT_FLOOR: 
            case VECT_SIGN: 
            case VECT_CUMSUM: 
            case VECT_CUMMIN: 
            case VECT_CUMMAX: {
                return "u(v" + this._type.name().toLowerCase() + ")";
            }
            case LOOKUP_R: {
                return "u(ixr)";
            }
            case LOOKUP_C: {
                return "u(ixc)";
            }
            case LOOKUP_RC: {
                return "u(ixrc)";
            }
            case LOOKUP0: {
                return "u(ix0)";
            }
            case CBIND0: {
                return "u(cbind0)";
            }
            case POW2: {
                return "^2";
            }
        }
        return "u(" + this._type.name().toLowerCase() + ")";
    }

    @Override
    public void setOutputDims() {
        switch (this._type) {
            case VECT_EXP: 
            case VECT_POW2: 
            case VECT_MULT2: 
            case VECT_SQRT: 
            case VECT_LOG: 
            case VECT_ABS: 
            case VECT_ROUND: 
            case VECT_CEIL: 
            case VECT_FLOOR: 
            case VECT_SIGN: 
            case VECT_CUMSUM: 
            case VECT_CUMMIN: 
            case VECT_CUMMAX: {
                this._rows = ((CNode)this._inputs.get((int)0))._rows;
                this._cols = ((CNode)this._inputs.get((int)0))._cols;
                this._dataType = Expression.DataType.MATRIX;
                break;
            }
            case ROW_SUMS: 
            case ROW_MINS: 
            case ROW_MAXS: 
            case EXP: 
            case LOOKUP_R: 
            case LOOKUP_C: 
            case LOOKUP_RC: 
            case LOOKUP0: 
            case CBIND0: 
            case POW2: 
            case MULT2: 
            case ABS: 
            case SIN: 
            case COS: 
            case TAN: 
            case ASIN: 
            case ACOS: 
            case ATAN: 
            case SIGN: 
            case SQRT: 
            case LOG: 
            case ROUND: 
            case CEIL: 
            case FLOOR: 
            case SELP: 
            case SPROP: 
            case SIGMOID: 
            case LOG_NZ: {
                this._rows = 0L;
                this._cols = 0L;
                this._dataType = Expression.DataType.SCALAR;
                break;
            }
            default: {
                throw new RuntimeException("Operation " + this._type.toString() + " has no output dimensions, dimensions needs to be specified for the CNode ");
            }
        }
    }

    @Override
    public int hashCode() {
        if (this._hash == 0) {
            this._hash = UtilFunctions.intHashCode(super.hashCode(), this._type.hashCode());
        }
        return this._hash;
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof CNodeUnary)) {
            return false;
        }
        CNodeUnary that = (CNodeUnary)o;
        return super.equals(that) && this._type == that._type;
    }

    public static enum UnaryType {
        LOOKUP_R,
        LOOKUP_C,
        LOOKUP_RC,
        LOOKUP0,
        CBIND0,
        ROW_SUMS,
        ROW_MINS,
        ROW_MAXS,
        VECT_EXP,
        VECT_POW2,
        VECT_MULT2,
        VECT_SQRT,
        VECT_LOG,
        VECT_ABS,
        VECT_ROUND,
        VECT_CEIL,
        VECT_FLOOR,
        VECT_SIGN,
        VECT_CUMSUM,
        VECT_CUMMIN,
        VECT_CUMMAX,
        EXP,
        POW2,
        MULT2,
        SQRT,
        LOG,
        LOG_NZ,
        ABS,
        ROUND,
        CEIL,
        FLOOR,
        SIGN,
        SIN,
        COS,
        TAN,
        ASIN,
        ACOS,
        ATAN,
        SELP,
        SPROP,
        SIGMOID;


        public static boolean contains(String value) {
            for (UnaryType ut : UnaryType.values()) {
                if (!ut.name().equals(value)) continue;
                return true;
            }
            return false;
        }

        public String getTemplate(boolean sparse) {
            switch (this) {
                case ROW_SUMS: 
                case ROW_MINS: 
                case ROW_MAXS: {
                    String vectName = StringUtils.capitalize((String)this.toString().substring(4, 7).toLowerCase());
                    return sparse ? "    double %TMP% = LibSpoofPrimitives.vect" + vectName + "(%IN1v%, %IN1i%, %POS1%, alen, len);\n" : "    double %TMP% = LibSpoofPrimitives.vect" + vectName + "(%IN1%, %POS1%, %LEN%);\n";
                }
                case VECT_EXP: 
                case VECT_POW2: 
                case VECT_MULT2: 
                case VECT_SQRT: 
                case VECT_LOG: 
                case VECT_ABS: 
                case VECT_ROUND: 
                case VECT_CEIL: 
                case VECT_FLOOR: 
                case VECT_SIGN: 
                case VECT_CUMSUM: 
                case VECT_CUMMIN: 
                case VECT_CUMMAX: {
                    String vectName = this.getVectorPrimitiveName();
                    return sparse ? "    double[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" : "    double[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %POS1%, %LEN%);\n";
                }
                case EXP: {
                    return "    double %TMP% = FastMath.exp(%IN1%);\n";
                }
                case LOOKUP_R: {
                    return "    double %TMP% = getValue(%IN1%, rowIndex);\n";
                }
                case LOOKUP_C: {
                    return "    double %TMP% = getValue(%IN1%, n, 0, colIndex);\n";
                }
                case LOOKUP_RC: {
                    return "    double %TMP% = getValue(%IN1%, n, rowIndex, colIndex);\n";
                }
                case LOOKUP0: {
                    return "    double %TMP% = %IN1%[0];\n";
                }
                case CBIND0: {
                    return "    double %TMP% = %IN1%; rowIndex *= 2;\n";
                }
                case POW2: {
                    return "    double %TMP% = %IN1% * %IN1%;\n";
                }
                case MULT2: {
                    return "    double %TMP% = %IN1% + %IN1%;\n";
                }
                case ABS: {
                    return "    double %TMP% = Math.abs(%IN1%);\n";
                }
                case SIN: {
                    return "    double %TMP% = FastMath.sin(%IN1%);\n";
                }
                case COS: {
                    return "    double %TMP% = FastMath.cos(%IN1%);\n";
                }
                case TAN: {
                    return "    double %TMP% = FastMath.tan(%IN1%);\n";
                }
                case ASIN: {
                    return "    double %TMP% = FastMath.asin(%IN1%);\n";
                }
                case ACOS: {
                    return "    double %TMP% = FastMath.acos(%IN1%);\n";
                }
                case ATAN: {
                    return "    double %TMP% = Math.atan(%IN1%);\n";
                }
                case SIGN: {
                    return "    double %TMP% = FastMath.signum(%IN1%);\n";
                }
                case SQRT: {
                    return "    double %TMP% = Math.sqrt(%IN1%);\n";
                }
                case LOG: {
                    return "    double %TMP% = FastMath.log(%IN1%);\n";
                }
                case ROUND: {
                    return "    double %TMP% = Math.round(%IN1%);\n";
                }
                case CEIL: {
                    return "    double %TMP% = FastMath.ceil(%IN1%);\n";
                }
                case FLOOR: {
                    return "    double %TMP% = FastMath.floor(%IN1%);\n";
                }
                case SELP: {
                    return "    double %TMP% = (%IN1%>0) ? %IN1% : 0;\n";
                }
                case SPROP: {
                    return "    double %TMP% = %IN1% * (1 - %IN1%);\n";
                }
                case SIGMOID: {
                    return "    double %TMP% = 1 / (1 + FastMath.exp(-%IN1%));\n";
                }
                case LOG_NZ: {
                    return "    double %TMP% = (%IN1%==0) ? 0 : FastMath.log(%IN1%);\n";
                }
            }
            throw new RuntimeException("Invalid unary type: " + this.toString());
        }

        public boolean isVectorScalarPrimitive() {
            return this == VECT_EXP || this == VECT_POW2 || this == VECT_MULT2 || this == VECT_SQRT || this == VECT_LOG || this == VECT_ABS || this == VECT_ROUND || this == VECT_CEIL || this == VECT_FLOOR || this == VECT_SIGN || this == VECT_CUMSUM || this == VECT_CUMMIN || this == VECT_CUMMAX;
        }

        public UnaryType getVectorAddPrimitive() {
            return UnaryType.valueOf("VECT_" + this.getVectorPrimitiveName().toUpperCase() + "_ADD");
        }

        public String getVectorPrimitiveName() {
            String[] tmp = this.name().split("_");
            return StringUtils.capitalize((String)tmp[1].toLowerCase());
        }

        public boolean isScalarLookup() {
            return this == LOOKUP0 || this == LOOKUP_R || this == LOOKUP_C || this == LOOKUP_RC;
        }
    }
}

