/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.recompile;

import java.util.ArrayList;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.utils.Statistics;

/**
 * Recompile-time rewrite that replaces hops whose value can be computed from the current
 * symbol table ({@link LocalVariableMap}) by {@link LiteralOp}s.
 *
 * <p>Handled patterns: non-persistent scalar reads, value-type casts (INT/DOUBLE/BOOLEAN) of
 * scalar reads and of literals, CAST_AS_SCALAR over a matrix read or over a 1x1 right
 * indexing, and full (RowCol) SUM/SUM_SQ/MIN/MAX aggregates over a matrix read or over a
 * right indexing. Patterns that read matrix data by indexing or aggregation are limited to
 * matrices with fewer than {@link #REPLACE_LITERALS_MAX_MATRIX_SIZE} cells according to the
 * matrix object's metadata.
 */
public class LiteralReplacement {
    /** Exclusive upper bound on the cell count (from matrix metadata) of matrices read for replacement. */
    private static final long REPLACE_LITERALS_MAX_MATRIX_SIZE = 1000000L;
    /** Whether unary-aggregate replacements are reported as "rlit" heavy hitters when statistics are on. */
    private static final boolean REPORT_LITERAL_REPLACE_OPS_STATS = true;

    /**
     * Recursively replaces the inputs of {@code hop} by literals where their values can be
     * derived from {@code vars}. A replaced input is not descended into; an input that cannot
     * be replaced is processed recursively. Hops already marked visited are skipped, and every
     * processed hop is marked visited.
     *
     * @param hop         root of the (sub-)DAG to process
     * @param vars        current symbol table used to resolve variable values
     * @param scalarsOnly if true, only the scalar-read and value-type-cast patterns are tried;
     *                    patterns that read matrix data are skipped
     * @throws DMLRuntimeException if a replacement fails, e.g. a non-1x1 matrix is cast to scalar
     */
    protected static void rReplaceLiterals(Hop hop, LocalVariableMap vars, boolean scalarsOnly) throws DMLRuntimeException {
        if (hop.isVisited()) {
            return;
        }
        if (hop.getInput() != null) {
            for (int i = 0; i < hop.getInput().size(); ++i) {
                Hop c = hop.getInput().get(i);
                LiteralOp lit = replaceLiteral(c, vars, scalarsOnly);
                if (lit == null) {
                    // not replaceable as a whole; try its inputs instead
                    rReplaceLiterals(c, vars, scalarsOnly);
                } else if (c.getParent().size() > 1) {
                    // Shared child: rewire every parent to the same literal. Iterate over a copy
                    // because rewiring may modify c's parent list.
                    ArrayList<Hop> parents = new ArrayList<Hop>(c.getParent());
                    for (Hop p : parents) {
                        int pos = HopRewriteUtils.getChildReferencePos(p, c);
                        HopRewriteUtils.removeChildReferenceByPos(p, c, pos);
                        HopRewriteUtils.addChildReference(p, lit, pos);
                    }
                } else {
                    HopRewriteUtils.replaceChildReference(hop, c, lit, i);
                }
            }
        }
        hop.setVisited();
    }

    /**
     * Tries all replacement patterns in a fixed order and returns the first literal produced.
     *
     * @return the replacement literal, or {@code null} if no pattern applies
     */
    private static LiteralOp replaceLiteral(Hop c, LocalVariableMap vars, boolean scalarsOnly) throws DMLRuntimeException {
        LiteralOp lit = replaceLiteralScalarRead(c, vars);
        if (lit == null) {
            lit = replaceLiteralValueTypeCastScalarRead(c, vars);
        }
        if (lit == null) {
            lit = replaceLiteralValueTypeCastLiteral(c, vars);
        }
        if (scalarsOnly) {
            return lit;
        }
        if (lit == null) {
            lit = replaceLiteralDataTypeCastMatrixRead(c, vars);
        }
        if (lit == null) {
            lit = replaceLiteralValueTypeCastRightIndexing(c, vars);
        }
        if (lit == null) {
            lit = replaceLiteralFullUnaryAggregate(c, vars);
        }
        if (lit == null) {
            lit = replaceLiteralFullUnaryAggregateRightIndexing(c, vars);
        }
        return lit;
    }

    /**
     * Replaces a non-persistent scalar {@link DataOp} read whose variable is bound to a scalar
     * in {@code vars}. Only INT, DOUBLE and BOOLEAN values are replaced.
     *
     * @return the literal, or {@code null} if the pattern does not apply
     */
    private static LiteralOp replaceLiteralScalarRead(Hop c, LocalVariableMap vars) {
        if (!(c instanceof DataOp) || ((DataOp) c).getDataOpType() == Hop.DataOpTypes.PERSISTENTREAD
            || c.getDataType() != Expression.DataType.SCALAR) {
            return null;
        }
        Data dat = vars.get(c.getName());
        if (!(dat instanceof ScalarObject)) {
            return null;
        }
        ScalarObject sdat = (ScalarObject) dat;
        switch (sdat.getValueType()) {
            case INT:
                return new LiteralOp(sdat.getLongValue());
            case DOUBLE:
                return new LiteralOp(sdat.getDoubleValue());
            case BOOLEAN:
                return new LiteralOp(sdat.getBooleanValue());
            default:
                return null;
        }
    }

    /** Returns true if {@code c} is a CAST_AS_DOUBLE, CAST_AS_INT or CAST_AS_BOOLEAN unary op. */
    private static boolean isValueTypeCast(Hop c) {
        if (!(c instanceof UnaryOp)) {
            return false;
        }
        Hop.OpOp1 op = ((UnaryOp) c).getOp();
        return op == Hop.OpOp1.CAST_AS_DOUBLE || op == Hop.OpOp1.CAST_AS_INT || op == Hop.OpOp1.CAST_AS_BOOLEAN;
    }

    /**
     * Replaces a scalar value-type cast of a {@link DataOp} whose variable is bound to a scalar
     * in {@code vars}. The scalar is converted to the cast's target type.
     *
     * @return the literal, or {@code null} if the pattern does not apply
     */
    private static LiteralOp replaceLiteralValueTypeCastScalarRead(Hop c, LocalVariableMap vars) {
        if (!isValueTypeCast(c) || !(c.getInput().get(0) instanceof DataOp)
            || c.getDataType() != Expression.DataType.SCALAR) {
            return null;
        }
        Data dat = vars.get(c.getInput().get(0).getName());
        if (!(dat instanceof ScalarObject)) {
            return null;
        }
        ScalarObject sdat = (ScalarObject) dat;
        switch (((UnaryOp) c).getOp()) {
            case CAST_AS_INT:
                return new LiteralOp(sdat.getLongValue());
            case CAST_AS_DOUBLE:
                return new LiteralOp(sdat.getDoubleValue());
            case CAST_AS_BOOLEAN:
                return new LiteralOp(sdat.getBooleanValue());
            default:
                return null;
        }
    }

    /**
     * Constant-folds a value-type cast whose input is already a {@link LiteralOp}.
     *
     * @return the literal, or {@code null} if the pattern does not apply
     * @throws DMLRuntimeException wrapping a {@link HopsException} from the literal conversion
     */
    private static LiteralOp replaceLiteralValueTypeCastLiteral(Hop c, LocalVariableMap vars) throws DMLRuntimeException {
        if (!isValueTypeCast(c) || !(c.getInput().get(0) instanceof LiteralOp)) {
            return null;
        }
        LiteralOp sdat = (LiteralOp) c.getInput().get(0);
        try {
            switch (((UnaryOp) c).getOp()) {
                case CAST_AS_INT: {
                    long ival = HopRewriteUtils.getIntValue(sdat);
                    return new LiteralOp(ival);
                }
                case CAST_AS_DOUBLE: {
                    double dval = HopRewriteUtils.getDoubleValue(sdat);
                    return new LiteralOp(dval);
                }
                case CAST_AS_BOOLEAN: {
                    boolean bval = HopRewriteUtils.getBooleanValue(sdat);
                    return new LiteralOp(bval);
                }
                default:
                    return null;
            }
        }
        catch (HopsException ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /**
     * Replaces CAST_AS_SCALAR over a matrix {@link DataOp} whose variable is bound in
     * {@code vars} by the matrix's single cell value. No size limit is applied here.
     *
     * @return the literal, or {@code null} if the pattern does not apply
     * @throws DMLRuntimeException if the bound matrix is not 1x1
     */
    private static LiteralOp replaceLiteralDataTypeCastMatrixRead(Hop c, LocalVariableMap vars) throws DMLRuntimeException {
        if (!(c instanceof UnaryOp) || ((UnaryOp) c).getOp() != Hop.OpOp1.CAST_AS_SCALAR) {
            return null;
        }
        Hop in = c.getInput().get(0);
        if (!(in instanceof DataOp) || in.getDataType() != Expression.DataType.MATRIX) {
            return null;
        }
        Data dat = vars.get(in.getName());
        if (!(dat instanceof MatrixObject)) {
            return null;
        }
        MatrixObject mo = (MatrixObject) dat;
        MatrixBlock mBlock = (MatrixBlock) mo.acquireRead();
        double value;
        try {
            if (mBlock.getNumRows() != 1 || mBlock.getNumColumns() != 1) {
                throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix of dimension (" + mBlock.getNumRows() + " x " + mBlock.getNumColumns() + ") to scalar.");
            }
            value = mBlock.getValue(0, 0);
        }
        finally {
            // release the read lock on both the success and the error path
            mo.release();
        }
        return new LiteralOp(value);
    }

    /**
     * Replaces CAST_AS_SCALAR over a right indexing with known 1x1 output dimensions. The
     * indexed matrix must be a bound {@link DataOp} below the size limit, and every index bound
     * must be resolvable via {@link #isIntValueDataLiteral}. Replacement is skipped when the
     * resolved cell lies outside the matrix, so the regular runtime indexing check still applies.
     *
     * @return the literal, or {@code null} if the pattern does not apply
     */
    private static LiteralOp replaceLiteralValueTypeCastRightIndexing(Hop c, LocalVariableMap vars) throws DMLRuntimeException {
        if (!(c instanceof UnaryOp) || ((UnaryOp) c).getOp() != Hop.OpOp1.CAST_AS_SCALAR
            || !(c.getInput().get(0) instanceof IndexingOp)
            || c.getInput().get(0).getDataType() != Expression.DataType.MATRIX) {
            return null;
        }
        IndexingOp rix = (IndexingOp) c.getInput().get(0);
        Hop data = rix.getInput().get(0);
        if (!rix.dimsKnown() || rix.getDim1() != 1L || rix.getDim2() != 1L
            || !(data instanceof DataOp) || !(vars.get(data.getName()) instanceof MatrixObject)
            || !hasIntValueIndexBounds(rix, vars)) {
            return null;
        }
        long rlval = getIntValueDataLiteral(rix.getInput().get(1), vars);
        long clval = getIntValueDataLiteral(rix.getInput().get(3), vars);
        MatrixObject mo = (MatrixObject) vars.get(data.getName());
        if (!isSmallMatrix(mo) || !isValidIndexRange(mo, rlval, rlval, clval, clval)) {
            return null;
        }
        MatrixBlock mBlock = (MatrixBlock) mo.acquireRead();
        double value;
        try {
            value = mBlock.getValue((int) (rlval - 1L), (int) (clval - 1L));
        }
        finally {
            mo.release();
        }
        return new LiteralOp(value);
    }

    /**
     * Replaces a replaceable full unary aggregate (see {@link #isReplaceableUnaryAggregate})
     * over a bound matrix {@link DataOp} below the size limit by its computed value.
     *
     * @return the literal, or {@code null} if the pattern does not apply
     */
    private static LiteralOp replaceLiteralFullUnaryAggregate(Hop c, LocalVariableMap vars) throws DMLRuntimeException {
        if (!(c instanceof AggUnaryOp) || !isReplaceableUnaryAggregate((AggUnaryOp) c)
            || !(c.getInput().get(0) instanceof DataOp)) {
            return null;
        }
        Data dat = vars.get(c.getInput().get(0).getName());
        if (!(dat instanceof MatrixObject)) {
            return null;
        }
        MatrixObject mo = (MatrixObject) dat;
        if (!isSmallMatrix(mo)) {
            return null;
        }
        MatrixBlock mBlock = (MatrixBlock) mo.acquireRead();
        double value;
        try {
            value = replaceUnaryAggregate((AggUnaryOp) c, mBlock);
        }
        finally {
            mo.release();
        }
        return new LiteralOp(value);
    }

    /**
     * Replaces a replaceable full unary aggregate over a right indexing of a bound matrix
     * {@link DataOp} below the size limit. The slice is taken on the in-memory block and then
     * aggregated. Replacement is skipped when the resolved index range is invalid for the matrix.
     *
     * @return the literal, or {@code null} if the pattern does not apply
     */
    private static LiteralOp replaceLiteralFullUnaryAggregateRightIndexing(Hop c, LocalVariableMap vars) throws DMLRuntimeException {
        if (!(c instanceof AggUnaryOp) || !isReplaceableUnaryAggregate((AggUnaryOp) c)
            || !(c.getInput().get(0) instanceof IndexingOp)) {
            return null;
        }
        IndexingOp rix = (IndexingOp) c.getInput().get(0);
        Hop data = rix.getInput().get(0);
        if (!(data instanceof DataOp) || !(vars.get(data.getName()) instanceof MatrixObject)
            || !hasIntValueIndexBounds(rix, vars)) {
            return null;
        }
        long rlval = getIntValueDataLiteral(rix.getInput().get(1), vars);
        long ruval = getIntValueDataLiteral(rix.getInput().get(2), vars);
        long clval = getIntValueDataLiteral(rix.getInput().get(3), vars);
        long cuval = getIntValueDataLiteral(rix.getInput().get(4), vars);
        MatrixObject mo = (MatrixObject) vars.get(data.getName());
        if (!isSmallMatrix(mo) || !isValidIndexRange(mo, rlval, ruval, clval, cuval)) {
            return null;
        }
        MatrixBlock mBlock = (MatrixBlock) mo.acquireRead();
        double value;
        try {
            MatrixBlock slice = mBlock.sliceOperations((int) (rlval - 1L), (int) (ruval - 1L), (int) (clval - 1L), (int) (cuval - 1L), new MatrixBlock());
            value = replaceUnaryAggregate((AggUnaryOp) c, slice);
        }
        finally {
            mo.release();
        }
        return new LiteralOp(value);
    }

    /** Returns true if all four index-bound inputs (rl, ru, cl, cu) of {@code rix} are resolvable. */
    private static boolean hasIntValueIndexBounds(IndexingOp rix, LocalVariableMap vars) {
        for (int pos = 1; pos <= 4; ++pos) {
            if (!isIntValueDataLiteral(rix.getInput().get(pos), vars)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true if the matrix metadata reports known (non-negative) dimensions with fewer
     * than {@link #REPLACE_LITERALS_MAX_MATRIX_SIZE} cells. Unknown dimensions are rejected so
     * that a large matrix is never read just because its metadata is incomplete.
     */
    private static boolean isSmallMatrix(MatrixObject mo) {
        long rows = mo.getNumRows();
        long cols = mo.getNumColumns();
        return rows >= 0 && cols >= 0 && rows * cols < REPLACE_LITERALS_MAX_MATRIX_SIZE;
    }

    /** Returns true if the 1-based inclusive range [rl:ru, cl:cu] lies within the matrix dimensions. */
    private static boolean isValidIndexRange(MatrixObject mo, long rl, long ru, long cl, long cu) {
        return 1 <= rl && rl <= ru && ru <= mo.getNumRows()
            && 1 <= cl && cl <= cu && cu <= mo.getNumColumns();
    }

    /**
     * Returns true if {@code h} is a literal, a {@link DataOp} bound in {@code vars}, or an
     * NROW/NCOL over a {@link DataOp} bound in {@code vars}, i.e. a form accepted by
     * {@link #getIntValueDataLiteral}.
     */
    private static boolean isIntValueDataLiteral(Hop h, LocalVariableMap vars) {
        if (h instanceof LiteralOp) {
            return true;
        }
        if (h instanceof DataOp) {
            return vars.keySet().contains(h.getName());
        }
        if (h instanceof UnaryOp) {
            Hop.OpOp1 op = ((UnaryOp) h).getOp();
            return (op == Hop.OpOp1.NROW || op == Hop.OpOp1.NCOL)
                && h.getInput().get(0) instanceof DataOp
                && vars.keySet().contains(h.getInput().get(0).getName());
        }
        return false;
    }

    /**
     * Resolves the integer value of a hop accepted by {@link #isIntValueDataLiteral}: the literal
     * value, the row/column count from the matrix object's metadata for NROW/NCOL, or otherwise
     * the long value of the scalar bound to the hop's name.
     *
     * @throws DMLRuntimeException if the literal cannot be converted to an integer
     */
    private static long getIntValueDataLiteral(Hop hop, LocalVariableMap vars) throws DMLRuntimeException {
        long value = -1L;
        try {
            if (hop instanceof LiteralOp) {
                value = HopRewriteUtils.getIntValue((LiteralOp) hop);
            } else if (hop instanceof UnaryOp && ((UnaryOp) hop).getOp() == Hop.OpOp1.NROW) {
                MatrixObject mo = (MatrixObject) vars.get(hop.getInput().get(0).getName());
                value = mo.getNumRows();
            } else if (hop instanceof UnaryOp && ((UnaryOp) hop).getOp() == Hop.OpOp1.NCOL) {
                MatrixObject mo = (MatrixObject) vars.get(hop.getInput().get(0).getName());
                value = mo.getNumColumns();
            } else {
                ScalarObject sdat = (ScalarObject) vars.get(hop.getName());
                value = sdat.getLongValue();
            }
        }
        catch (HopsException ex) {
            throw new DMLRuntimeException("Failed to get int value for literal replacement", ex);
        }
        return value;
    }

    /** Returns true for full (RowCol) aggregates with op SUM, SUM_SQ, MIN or MAX. */
    private static boolean isReplaceableUnaryAggregate(AggUnaryOp auop) {
        if (auop.getDirection() != Hop.Direction.RowCol) {
            return false;
        }
        Hop.AggOp op = auop.getOp();
        return op == Hop.AggOp.SUM || op == Hop.AggOp.SUM_SQ || op == Hop.AggOp.MIN || op == Hop.AggOp.MAX;
    }

    /**
     * Computes the full aggregate of {@code mb} for the op of {@code auop}. When statistics are
     * enabled and {@link #REPORT_LITERAL_REPLACE_OPS_STATS} is set, the elapsed time is recorded
     * under the heavy-hitter key "rlit".
     *
     * @throws DMLRuntimeException for aggregate ops other than SUM, SUM_SQ, MIN and MAX
     */
    private static double replaceUnaryAggregate(AggUnaryOp auop, MatrixBlock mb) throws DMLRuntimeException {
        boolean reportStats = DMLScript.STATISTICS && REPORT_LITERAL_REPLACE_OPS_STATS;
        long t0 = reportStats ? System.nanoTime() : 0L;
        double val;
        switch (auop.getOp()) {
            case SUM:
                val = mb.sum();
                break;
            case SUM_SQ:
                val = mb.sumSq();
                break;
            case MIN:
                val = mb.min();
                break;
            case MAX:
                val = mb.max();
                break;
            default:
                throw new DMLRuntimeException("Unsupported unary aggregate replacement: " + auop.getOp());
        }
        if (reportStats) {
            long t1 = System.nanoTime();
            Statistics.maintainCPHeavyHitters("rlit", t1 - t0);
        }
        return val;
    }
}

