/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.scripts.nn.layers;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.scripts.nn.layers.batch_norm2d.Backward_output;
import org.apache.sysml.scripts.nn.layers.batch_norm2d.Forward_output;
import org.apache.sysml.scripts.nn.layers.batch_norm2d.Init_output;

public class Batch_norm2d
extends Script {
    public Batch_norm2d() {
        String string = "scripts/nn/layers/batch_norm2d.dml";
        InputStream inputStream = Script.class.getResourceAsStream(new StringBuffer().append("/").append(string).toString());
        InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
        char[] cArray = new char[1024];
        StringBuilder stringBuilder = new StringBuilder();
        try {
            int n;
            while ((n = inputStreamReader.read(cArray)) > 0) {
                stringBuilder.append(cArray, 0, n);
            }
        }
        catch (IOException iOException) {
            iOException.printStackTrace();
        }
        this.setScriptString(stringBuilder.toString());
    }

    public Init_output init(Object object) {
        String string = "source('scripts/nn/layers/batch_norm2d.dml') as mlcontextns;[gamma, beta, ema_mean, ema_var] = mlcontextns::init(C);";
        Script script = new Script(string);
        script.in("C", object).out("gamma").out("beta").out("ema_mean").out("ema_var");
        MLResults mLResults = script.execute();
        Matrix matrix = mLResults.getMatrix("gamma");
        Matrix matrix2 = mLResults.getMatrix("beta");
        Matrix matrix3 = mLResults.getMatrix("ema_mean");
        Matrix matrix4 = mLResults.getMatrix("ema_var");
        Init_output init_output = new Init_output(matrix, matrix2, matrix3, matrix4);
        return init_output;
    }

    public String init__docs() {
        String string = "init = function(int C)\n    return (matrix[double] gamma, matrix[double] beta,\n            matrix[double] ema_mean, matrix[double] ema_var) {\n  /*\n   * Initialize the parameters of this layer.\n   *\n   * Note: This is just a convenience function, and parameters\n   * may be initialized manually if needed.\n   *\n   * Inputs:\n   *  - C: Number of input channels (dimensionality of input depth).\n   *\n   * Outputs:\n   *  - gamma: Scale parameters, of shape (C, 1).\n   *  - beta: Shift parameters, of shape (C, 1).\n   *  - ema_mean: Exponential moving average of the mean, of\n   *      shape (C, 1).\n   *  - ema_var: Exponential moving average of the variance, of\n   *      shape (C, 1).\n   */\n";
        return string;
    }

    public String init__source() {
        String string = "init = function(int C)\n    return (matrix[double] gamma, matrix[double] beta,\n            matrix[double] ema_mean, matrix[double] ema_var) {\n  /*\n   * Initialize the parameters of this layer.\n   *\n   * Note: This is just a convenience function, and parameters\n   * may be initialized manually if needed.\n   *\n   * Inputs:\n   *  - C: Number of input channels (dimensionality of input depth).\n   *\n   * Outputs:\n   *  - gamma: Scale parameters, of shape (C, 1).\n   *  - beta: Shift parameters, of shape (C, 1).\n   *  - ema_mean: Exponential moving average of the mean, of\n   *      shape (C, 1).\n   *  - ema_var: Exponential moving average of the variance, of\n   *      shape (C, 1).\n   */\n   gamma = matrix(1, rows=C, cols=1)\n   beta = matrix(0, rows=C, cols=1)\n   ema_mean = matrix(0, rows=C, cols=1)\n   ema_var = matrix(1, rows=C, cols=1)\n}\n";
        return string;
    }

    public Forward_output forward(Object object, Object object2, Object object3, Object object4, Object object5, Object object6, Object object7, Object object8, Object object9, Object object10, Object object11) {
        String string = "source('scripts/nn/layers/batch_norm2d.dml') as mlcontextns;[out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] = mlcontextns::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, epsilon);";
        Script script = new Script(string);
        script.in("X", object).in("gamma", object2).in("beta", object3).in("C", object4).in("Hin", object5).in("Win", object6).in("mode", object7).in("ema_mean", object8).in("ema_var", object9).in("mu", object10).in("epsilon", object11).out("out").out("ema_mean_upd").out("ema_var_upd").out("cache_mean").out("cache_var").out("cache_norm");
        MLResults mLResults = script.execute();
        Matrix matrix = mLResults.getMatrix("out");
        Matrix matrix2 = mLResults.getMatrix("ema_mean_upd");
        Matrix matrix3 = mLResults.getMatrix("ema_var_upd");
        Matrix matrix4 = mLResults.getMatrix("cache_mean");
        Matrix matrix5 = mLResults.getMatrix("cache_var");
        Matrix matrix6 = mLResults.getMatrix("cache_norm");
        Forward_output forward_output = new Forward_output(matrix, matrix2, matrix3, matrix4, matrix5, matrix6);
        return forward_output;
    }

    public String forward__docs() {
        String string = "forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,\n                   int C, int Hin, int Win, string mode,\n                   matrix[double] ema_mean, matrix[double] ema_var,\n                   double mu, double epsilon)\n    return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,\n            matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {\n  /*\n   * Computes the forward pass for a 2D (spatial) batch normalization\n   * layer.  The input data has N examples, each represented as a 3D\n   * volume unrolled into a single vector.\n   *\n   * A spatial batch normalization layer uses the per-channel sample\n   * mean and per-channel uncorrected sample variance during training\n   * to normalize each channel of the input data.  Additionally, it\n   * introduces learnable parameters (gamma, beta) to control the\n   * amount of normalization.\n   *\n   *   `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`\n   *\n   * This implementation maintains exponential moving averages of the\n   * mean and variance during training for use during testing.\n   *\n   * Reference:\n   *  - Batch Normalization: Accelerating Deep Network Training by\n   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015\n   *    - https://arxiv.org/abs/1502.03167\n   *\n   * Inputs:\n   *  - X: Inputs, of shape (N, C*Hin*Win).\n   *  - gamma: Scale parameters, of shape (C, 1).\n   *  - beta: Shift parameters, of shape (C, 1).\n   *  - C: Number of input channels (dimensionality of input depth).\n   *  - Hin: Input height.\n   *  - Win: Input width.\n   *  - mode: 'train' or 'test' to indicate if the model is currently\n   *      being trained or tested.  During training, the current batch\n   *      mean and variance will be used to normalize the inputs, while\n   *      during testing, the exponential average of the mean and\n   *      variance over all previous batches will be used.\n   *  - ema_mean: Exponential moving average of the mean, of\n   *      shape (C, 1).\n   *  - ema_var: Exponential moving average of the variance, of\n   *      shape (C, 1).\n   *  - mu: Momentum value for moving averages.\n   *      Typical values are in the range of [0.9, 0.999].\n   *  - epsilon: Smoothing term to avoid divide by zero errors.\n   *      Typical values are in the range of [1e-5, 1e-3].\n   *\n   * Outputs:\n   *  - out: Outputs, of shape (N, C*Hin*Win).\n   *  - ema_mean_upd: Updated exponential moving average of the mean,\n   *      of shape (C, 1).\n   *  - ema_var_upd: Updated exponential moving average of the variance,\n   *      of shape (C, 1).\n   *  - cache_mean: Cache of the batch mean, of shape (C, 1).\n   *      Note: This is used for performance during training.\n   *  - cache_var: Cache of the batch variance, of shape (C, 1).\n   *      Note: This is used for performance during training.\n   *  - cache_norm: Cache of the normalized inputs, of\n   *      shape (C, N*Hin*Win). Note: This is used for performance\n   *      during training.\n   */\n";
        return string;
    }

    public String forward__source() {
        String string = "forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,\n                   int C, int Hin, int Win, string mode,\n                   matrix[double] ema_mean, matrix[double] ema_var,\n                   double mu, double epsilon)\n    return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,\n            matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {\n  /*\n   * Computes the forward pass for a 2D (spatial) batch normalization\n   * layer.  The input data has N examples, each represented as a 3D\n   * volume unrolled into a single vector.\n   *\n   * A spatial batch normalization layer uses the per-channel sample\n   * mean and per-channel uncorrected sample variance during training\n   * to normalize each channel of the input data.  Additionally, it\n   * introduces learnable parameters (gamma, beta) to control the\n   * amount of normalization.\n   *\n   *   `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`\n   *\n   * This implementation maintains exponential moving averages of the\n   * mean and variance during training for use during testing.\n   *\n   * Reference:\n   *  - Batch Normalization: Accelerating Deep Network Training by\n   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015\n   *    - https://arxiv.org/abs/1502.03167\n   *\n   * Inputs:\n   *  - X: Inputs, of shape (N, C*Hin*Win).\n   *  - gamma: Scale parameters, of shape (C, 1).\n   *  - beta: Shift parameters, of shape (C, 1).\n   *  - C: Number of input channels (dimensionality of input depth).\n   *  - Hin: Input height.\n   *  - Win: Input width.\n   *  - mode: 'train' or 'test' to indicate if the model is currently\n   *      being trained or tested.  During training, the current batch\n   *      mean and variance will be used to normalize the inputs, while\n   *      during testing, the exponential average of the mean and\n   *      variance over all previous batches will be used.\n   *  - ema_mean: Exponential moving average of the mean, of\n   *      shape (C, 1).\n   *  - ema_var: Exponential moving average of the variance, of\n   *      shape (C, 1).\n   *  - mu: Momentum value for moving averages.\n   *      Typical values are in the range of [0.9, 0.999].\n   *  - epsilon: Smoothing term to avoid divide by zero errors.\n   *      Typical values are in the range of [1e-5, 1e-3].\n   *\n   * Outputs:\n   *  - out: Outputs, of shape (N, C*Hin*Win).\n   *  - ema_mean_upd: Updated exponential moving average of the mean,\n   *      of shape (C, 1).\n   *  - ema_var_upd: Updated exponential moving average of the variance,\n   *      of shape (C, 1).\n   *  - cache_mean: Cache of the batch mean, of shape (C, 1).\n   *      Note: This is used for performance during training.\n   *  - cache_var: Cache of the batch variance, of shape (C, 1).\n   *      Note: This is used for performance during training.\n   *  - cache_norm: Cache of the normalized inputs, of\n   *      shape (C, N*Hin*Win). Note: This is used for performance\n   *      during training.\n   */\n  N = nrow(X)\n\n  if (mode == 'train') {\n    # Compute channel-wise mean and variance\n    # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.\n    #  - mean of total group is mean of subgroup means\n    #  - variance is the mean of the subgroup variances + the variance of the subgroup means\n    subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)\n    subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win)  # uncorrected variances\n    mean = rowMeans(subgrp_means)  # shape (C, 1)\n    var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win))  # shape (C, 1)\n    # Update moving averages\n    ema_mean_upd = mu*ema_mean + (1-mu)*mean\n    ema_var_upd = mu*ema_var + (1-mu)*var\n  }\n  else {\n    # Use moving averages of mean and variance during testing\n    mean = ema_mean\n    var = ema_var\n    ema_mean_upd = ema_mean\n    ema_var_upd = ema_var\n  }\n\n  # Normalize, shift, and scale\n  # norm = (X-mean)*(var+epsilon)^(-1/2)\n  #      = (X-mean) / sqrt(var+epsilon)\n  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)\n  norm = bias_multiply(centered, 1/sqrt(var+epsilon))  # shape (N, C*Hin*Win)\n  # out = norm*gamma + beta\n  scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)\n  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)\n\n  # Save variable for backward pass\n  cache_mean = mean\n  cache_var = var\n  cache_norm = norm\n}\n";
        return string;
    }

    public Backward_output backward(Object object, Object object2, Object object3, Object object4, Object object5, Object object6, Object object7, Object object8, Object object9, Object object10, Object object11, Object object12, Object object13, Object object14, Object object15, Object object16, Object object17, Object object18) {
        String string = "source('scripts/nn/layers/batch_norm2d.dml') as mlcontextns;[dX, dgamma, dbeta] = mlcontextns::backward(dout, out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm, X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, epsilon);";
        Script script = new Script(string);
        script.in("dout", object).in("out", object2).in("ema_mean_upd", object3).in("ema_var_upd", object4).in("cache_mean", object5).in("cache_var", object6).in("cache_norm", object7).in("X", object8).in("gamma", object9).in("beta", object10).in("C", object11).in("Hin", object12).in("Win", object13).in("mode", object14).in("ema_mean", object15).in("ema_var", object16).in("mu", object17).in("epsilon", object18).out("dX").out("dgamma").out("dbeta");
        MLResults mLResults = script.execute();
        Matrix matrix = mLResults.getMatrix("dX");
        Matrix matrix2 = mLResults.getMatrix("dgamma");
        Matrix matrix3 = mLResults.getMatrix("dbeta");
        Backward_output backward_output = new Backward_output(matrix, matrix2, matrix3);
        return backward_output;
    }

    public String backward__docs() {
        String string = "backward = function(matrix[double] dout, matrix[double] out,\n                    matrix[double] ema_mean_upd, matrix[double] ema_var_upd,\n                    matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm,\n                    matrix[double] X, matrix[double] gamma, matrix[double] beta,\n                    int C, int Hin, int Win, string mode,\n                    matrix[double] ema_mean, matrix[double] ema_var,\n                    double mu, double epsilon)\n      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {\n  /*\n   * Computes the backward pass for a 2D (spatial) batch normalization\n   * layer.\n   *\n   * Inputs:\n   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).\n   *  - out: Outputs from the forward pass, of shape (N, C*Hin*Win).\n   *  - ema_mean_upd: Updated exponential moving average of the mean\n   *      from the forward pass, of shape (C, 1).\n   *  - ema_var_upd: Updated exponential moving average of the variance\n   *      from the forward pass, of shape (C, 1).\n   *  - cache_mean: Cache of the batch mean from the forward pass, of\n   *      shape (C, 1).  Note: This is used for performance during\n   *      training.\n   *  - cache_var: Cache of the batch variance from the forward pass,\n   *      of shape (C, 1).  Note: This is used for performance during\n   *      training.\n   *  - cache_norm: Cache of the normalized inputs from the forward\n   *      pass, of shape (C, N*Hin*Win).  Note: This is used for\n   *      performance during training.\n   *  - X: Input data matrix to the forward pass, of\n   *      shape (N, C*Hin*Win).\n   *  - gamma: Scale parameters, of shape (C, 1).\n   *  - beta: Shift parameters, of shape (C, 1).\n   *  - C: Number of input channels (dimensionality of input depth).\n   *  - Hin: Input height.\n   *  - Win: Input width.\n   *  - mode: 'train' or 'test' to indicate if the model is currently\n   *      being trained or tested.  During training, the current batch\n   *      mean and variance will be used to normalize the inputs, while\n   *      during testing, the exponential average of the mean and\n   *      variance over all previous batches will be used.\n   *  - ema_mean: Exponential moving average of the mean, of\n   *      shape (C, 1).\n   *  - ema_var: Exponential moving average of the variance, of\n   *      shape (C, 1).\n   *  - mu: Momentum value for moving averages.\n   *      Typical values are in the range of [0.9, 0.999].\n   *  - epsilon: Smoothing term to avoid divide by zero errors.\n   *      Typical values are in the range of [1e-5, 1e-3].\n   *\n   * Outputs:\n   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).\n   *  - dgamma: Gradient wrt `W`, of shape (C, 1).\n   *  - dbeta: Gradient wrt `b`, of shape (C, 1).\n   *\n   */\n";
        return string;
    }

    public String backward__source() {
        String string = "backward = function(matrix[double] dout, matrix[double] out,\n                    matrix[double] ema_mean_upd, matrix[double] ema_var_upd,\n                    matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm,\n                    matrix[double] X, matrix[double] gamma, matrix[double] beta,\n                    int C, int Hin, int Win, string mode,\n                    matrix[double] ema_mean, matrix[double] ema_var,\n                    double mu, double epsilon)\n      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {\n  /*\n   * Computes the backward pass for a 2D (spatial) batch normalization\n   * layer.\n   *\n   * Inputs:\n   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).\n   *  - out: Outputs from the forward pass, of shape (N, C*Hin*Win).\n   *  - ema_mean_upd: Updated exponential moving average of the mean\n   *      from the forward pass, of shape (C, 1).\n   *  - ema_var_upd: Updated exponential moving average of the variance\n   *      from the forward pass, of shape (C, 1).\n   *  - cache_mean: Cache of the batch mean from the forward pass, of\n   *      shape (C, 1).  Note: This is used for performance during\n   *      training.\n   *  - cache_var: Cache of the batch variance from the forward pass,\n   *      of shape (C, 1).  Note: This is used for performance during\n   *      training.\n   *  - cache_norm: Cache of the normalized inputs from the forward\n   *      pass, of shape (C, N*Hin*Win).  Note: This is used for\n   *      performance during training.\n   *  - X: Input data matrix to the forward pass, of\n   *      shape (N, C*Hin*Win).\n   *  - gamma: Scale parameters, of shape (C, 1).\n   *  - beta: Shift parameters, of shape (C, 1).\n   *  - C: Number of input channels (dimensionality of input depth).\n   *  - Hin: Input height.\n   *  - Win: Input width.\n   *  - mode: 'train' or 'test' to indicate if the model is currently\n   *      being trained or tested.  During training, the current batch\n   *      mean and variance will be used to normalize the inputs, while\n   *      during testing, the exponential average of the mean and\n   *      variance over all previous batches will be used.\n   *  - ema_mean: Exponential moving average of the mean, of\n   *      shape (C, 1).\n   *  - ema_var: Exponential moving average of the variance, of\n   *      shape (C, 1).\n   *  - mu: Momentum value for moving averages.\n   *      Typical values are in the range of [0.9, 0.999].\n   *  - epsilon: Smoothing term to avoid divide by zero errors.\n   *      Typical values are in the range of [1e-5, 1e-3].\n   *\n   * Outputs:\n   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).\n   *  - dgamma: Gradient wrt `W`, of shape (C, 1).\n   *  - dbeta: Gradient wrt `b`, of shape (C, 1).\n   *\n   */\n  N = nrow(X)\n  mean = cache_mean\n  var = cache_var\n  norm = cache_norm\n  centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)\n\n  if (mode == 'train') {\n    # Compute gradients during training\n    dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)\n    dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)\n    dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)\n    dvar = util::channel_sums((-1/2) * bias_multiply(centered, (var+epsilon)^(-3/2)) * dnorm,\n                              C, Hin, Win)  # shape (C, 1)\n    dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -1/sqrt(var+epsilon)), C, Hin, Win)\n    dmean_var_branch =  util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)\n    dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within an expression yet\n    dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)\n    dX_norm_branch = bias_multiply(dnorm, 1/sqrt(var+epsilon))\n    dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)\n    dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)\n    dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)\n  }\n  else {\n    # Compute gradients during testing\n    dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)\n    dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)\n    dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)\n    dX = bias_multiply(dnorm, 1/sqrt(var+epsilon))  # shape (N, C*Hin*Win)\n  }\n}\n";
        return string;
    }
}

