/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.udf.lib;

import java.io.IOException;
import java.util.Iterator;
import java.util.Random;
import org.apache.sysml.runtime.controlprogram.caching.CacheException;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.Scalar;

/**
 * External function performing one SGD-with-Nesterov-momentum update step.
 *
 * <p>Function inputs, by position:
 * <ol start="0">
 *   <li>{@code X} (matrix)</li>
 *   <li>{@code dX} (matrix)</li>
 *   <li>{@code lr} (scalar)</li>
 *   <li>{@code mu} (scalar)</li>
 *   <li>{@code v} (matrix)</li>
 *   <li>{@code lambda} (scalar)</li>
 * </ol>
 *
 * <p>The function computes
 * <pre>
 *   v' = mu * v - lr * dX - lr * lambda * X
 *   X' = X - mu * v + (1 + mu) * v'
 * </pre>
 * and exposes {@code X'} as output 0 and {@code v'} as output 1. Both outputs are
 * materialized as dense blocks. {@code X}, {@code dX} and {@code v} are assumed to
 * have identical dimensions; this is not checked here.
 */
public class SGDNesterovUpdate extends PackageFunction {
    private static final long serialVersionUID = -3905212831582648882L;
    private Matrix updatedX;
    private Matrix updatedV;
    /** Source of suffixes for the temporary output matrix names. */
    private final Random rand = new Random();

    @Override
    public int getNumFunctionOutputs() {
        return 2;
    }

    /**
     * Returns output {@code pos}: 0 is the updated {@code X}, 1 is the updated {@code v}.
     *
     * @throws RuntimeException if {@code pos} is neither 0 nor 1
     */
    @Override
    public FunctionParameter getFunctionOutput(int pos) {
        if (pos == 0) {
            return this.updatedX;
        }
        if (pos == 1) {
            return this.updatedV;
        }
        throw new RuntimeException("Invalid function output being requested");
    }

    /** Returns true when {@code mb} is in dense format with an allocated dense array. */
    boolean isDense(MatrixBlock mb) {
        return !mb.isInSparseFormat() && mb.getDenseBlock() != null;
    }

    /**
     * Reads the inputs, computes {@code v'} and then {@code X'}, and stores both as outputs.
     * Every matrix input acquired for reading is released, even when the computation fails.
     *
     * @throws RuntimeException wrapping any {@link CacheException} or {@link IOException}
     */
    @Override
    public void execute() {
        try {
            // Parse scalars first so a malformed value cannot leave a matrix pinned.
            double lr = Double.parseDouble(((Scalar) this.getFunctionInput(2)).getValue());
            double mu = Double.parseDouble(((Scalar) this.getFunctionInput(3)).getValue());
            double lambda = Double.parseDouble(((Scalar) this.getFunctionInput(5)).getValue());
            Matrix xInput = (Matrix) this.getFunctionInput(0);
            Matrix dXInput = (Matrix) this.getFunctionInput(1);
            Matrix vInput = (Matrix) this.getFunctionInput(4);

            MatrixBlock X = (MatrixBlock) xInput.getMatrixObject().acquireRead();
            try {
                MatrixBlock dX = (MatrixBlock) dXInput.getMatrixObject().acquireRead();
                try {
                    MatrixBlock v = (MatrixBlock) vInput.getMatrixObject().acquireRead();
                    try {
                        double[] updatedVData = computeUpdatedV(X, dX, v, lr, mu, lambda);
                        computeUpdatedX(X, v, updatedVData, mu);
                    } finally {
                        vInput.getMatrixObject().release();
                    }
                } finally {
                    dXInput.getMatrixObject().release();
                }
            } finally {
                xInput.getMatrixObject().release();
            }
        } catch (CacheException e) {
            throw new RuntimeException("Exception while executing SGDNesterovUpdate", e);
        } catch (IOException e) {
            throw new RuntimeException("Exception while executing SGDNesterovUpdate", e);
        }
    }

    /**
     * Computes {@code v' = mu*v - lr*dX - lr*lambda*X} into a new dense block and stores it
     * in {@link #updatedV}.
     *
     * @return the dense array backing {@code v'}, reused by {@link #computeUpdatedX}
     */
    private double[] computeUpdatedV(MatrixBlock X, MatrixBlock dX, MatrixBlock v,
            double lr, double mu, double lambda) throws CacheException, IOException {
        this.updatedV = new Matrix("tmp_" + this.rand.nextLong(), v.getNumRows(), v.getNumColumns(), Matrix.ValueType.Double);
        MatrixBlock updatedVMB = this.allocateDenseMatrixBlock(this.updatedV);
        double[] updatedVData = updatedVMB.getDenseBlock();
        if (this.isDense(v) && this.isDense(dX) && this.isDense(X)) {
            // Fast path: single fused pass over the three dense arrays.
            double[] vArr = v.getDenseBlock();
            double[] dXArr = dX.getDenseBlock();
            double[] xArr = X.getDenseBlock();
            int nnz = 0;
            for (int i = 0; i < updatedVData.length; ++i) {
                updatedVData[i] = mu * vArr[i] - lr * dXArr[i] - lr * lambda * xArr[i];
                nnz += updatedVData[i] != 0.0 ? 1 : 0;
            }
            updatedVMB.setNonZeros(nnz);
        } else {
            // General path: accumulate each scaled term; works for sparse or dense inputs.
            this.multiplyByConstant(v, mu, updatedVData);
            this.multiplyByConstant(dX, -lr, updatedVData);
            this.multiplyByConstant(X, -lr * lambda, updatedVData);
            updatedVMB.recomputeNonZeros();
        }
        this.updatedV.setMatrixDoubleArray(updatedVMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
        return updatedVData;
    }

    /**
     * Computes {@code X' = X - mu*v + (1+mu)*v'} into a new dense block and stores it in
     * {@link #updatedX}. Here {@code updatedVData} is the dense array holding {@code v'}.
     */
    private void computeUpdatedX(MatrixBlock X, MatrixBlock v, double[] updatedVData, double mu)
            throws CacheException, IOException {
        this.updatedX = new Matrix("tmp_" + this.rand.nextLong(), X.getNumRows(), X.getNumColumns(), Matrix.ValueType.Double);
        MatrixBlock updatedXMB = this.allocateDenseMatrixBlock(this.updatedX);
        double[] updatedXData = updatedXMB.getDenseBlock();
        double muPlus1 = mu + 1.0;
        if (this.isDense(X) && this.isDense(v)) {
            double[] xArr = X.getDenseBlock();
            double[] vPrevArr = v.getDenseBlock();
            int nnz = 0;
            for (int i = 0; i < updatedXData.length; ++i) {
                updatedXData[i] = xArr[i] - mu * vPrevArr[i] + muPlus1 * updatedVData[i];
                nnz += updatedXData[i] != 0.0 ? 1 : 0;
            }
            updatedXMB.setNonZeros(nnz);
        } else if (this.isDense(v)) {
            // X is sparse or empty: materialize it first, then add the dense momentum terms.
            this.copy(X, updatedXData);
            double[] vPrevArr = v.getDenseBlock();
            int nnz = 0;
            for (int i = 0; i < updatedXData.length; ++i) {
                updatedXData[i] += -mu * vPrevArr[i] + muPlus1 * updatedVData[i];
                nnz += updatedXData[i] != 0.0 ? 1 : 0;
            }
            updatedXMB.setNonZeros(nnz);
        } else {
            this.copy(X, updatedXData);
            this.multiplyByConstant(v, -mu, updatedXData);
            this.multiplyByConstant(updatedVData, 1.0 + mu, updatedXData);
            updatedXMB.recomputeNonZeros();
        }
        this.updatedX.setMatrixDoubleArray(updatedXMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
    }

    /** Creates an empty dense block sized like {@code mat} and allocates its dense array. */
    private MatrixBlock allocateDenseMatrixBlock(Matrix mat) {
        int rows = (int) mat.getNumRows();
        int cols = (int) mat.getNumCols();
        MatrixBlock mb = new MatrixBlock(rows, cols, false);
        mb.allocateDenseBlock();
        return mb;
    }

    /** Adds {@code in[i] * constant} to {@code out[i]} for every index of {@code out}. */
    private void multiplyByConstant(double[] in, double constant, double[] out) {
        for (int i = 0; i < out.length; ++i) {
            out[i] += in[i] * constant;
        }
    }

    /**
     * Adds {@code constant * in} element-wise into the row-major dense array {@code out}.
     * Does nothing when {@code in} is dense-format with no allocated dense array.
     */
    private void multiplyByConstant(MatrixBlock in, double constant, double[] out) {
        if (in.isInSparseFormat()) {
            // Row-major linear index is row * numColumns + column (was row * column).
            int cols = in.getNumColumns();
            Iterator<IJV> iter = in.getSparseBlockIterator();
            while (iter.hasNext()) {
                IJV ijv = iter.next();
                out[ijv.getI() * cols + ijv.getJ()] += ijv.getV() * constant;
            }
            return;
        }
        double[] denseBlock = in.getDenseBlock();
        if (denseBlock == null) {
            return;
        }
        for (int i = 0; i < out.length; ++i) {
            out[i] += denseBlock[i] * constant;
        }
    }

    /**
     * Copies {@code src} into the row-major dense array {@code dest}. For a sparse
     * {@code src}, only non-zero cells are written; {@code dest} is expected to be
     * zero-filled beforehand.
     */
    private void copy(MatrixBlock src, double[] dest) {
        if (src.isInSparseFormat()) {
            // Row-major linear index is row * numColumns + column (was row * column).
            int cols = src.getNumColumns();
            Iterator<IJV> iter = src.getSparseBlockIterator();
            while (iter.hasNext()) {
                IJV ijv = iter.next();
                dest[ijv.getI() * cols + ijv.getJ()] = ijv.getV();
            }
        } else {
            double[] denseBlock = src.getDenseBlock();
            if (denseBlock != null) {
                System.arraycopy(denseBlock, 0, dest, 0, dest.length);
            }
        }
    }
}

