How to implement BatchNorm to use Running Average/Var in gluon

Hi there,

Im new to gluon and mxnet, and Im trying to implement a BatchNorm Layer that instead of using mean and variance, to compute the output, it should be using the running average and variance instead. I have found the documentation of the BatchNorm online and it says that theres a use_global_stats flag, but some discussion online here says that it should not be used in training. Does anyone have an idea of how I can implement what im trying to do here?

Thanks

Code that I am using now, but its not working

class ShiftScaleLayer(HybridBlock):
def init(self, axis=-1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
use_global_stats=True, fuse_relu=False,
beta_initializer=‘zeros’, gamma_initializer=‘ones’,
running_mean_initializer=‘zeros’, running_variance_initializer=‘ones’,
in_channels=0, **kwargs):
super(ShiftScaleLayer, self).init(**kwargs)
self._kwargs = {‘axis’: axis}
self.fuse_relu = fuse_relu
if in_channels != 0:
self.in_channels = in_channels

    self.momentum = Variable('momentum', shape=in_channels, init=Constant(momentum))
    self.epsilon = Variable('epsilon', shape=in_channels, init=Constant(epsilon))
    self.momentum = BlockGrad(self.momentum)
    self.epsilon = BlockGrad(self.epsilon)
    self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                 shape=(in_channels,), init=gamma_initializer,
                                 allow_deferred_init=True,
                                 differentiable=scale)
    self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                shape=(in_channels,), init=beta_initializer,
                                allow_deferred_init=True,
                                differentiable=center)
    self.running_mean = Variable('running_mean', shape=in_channels, init=initializer.Zero)
    self.running_var = Variable('running_mean', shape=in_channels, init=initializer.One)

def batch_norm(self, F, X, gamma, beta, moving_mean, moving_var, eps, momentum, axis):
    mean = F.mean(data=X, axis=axis)
    var = F.mean(data=F.square(X - mean), axis=axis)

    temp = F.broadcast_mul(momentum, moving_mean)
    moving_mean = F.broadcast_add(temp, (1.0 - momentum))
    moving_mean = moving_mean * mean

    moving_var = momentum * moving_var + (1.0 - momentum) * var

    X_hat = F.broadcast_div((X - moving_mean), F.sqrt(moving_var + eps))
    Y = F.broadcast_add(F.broadcast_mul(gamma, X_hat), beta)

    return Y, moving_mean, moving_var

def hybrid_forward(self, F, x, gamma, beta):
    Y, self.running_mean, self.running_var = self.batch_norm(F, x, gamma, beta, self.running_mean, self.running_var,
                      self.epsilon, self.momentum, **self._kwargs)
    return Y