MXNet - Use Batch Norm for Input Scaling

If I use BatchNorm with global stats and fix the gamma, that should be a reasonable approximation of feature normalization, except for the beta. Is there any way to fix the beta of BatchNorm as well?

Hi @dmadeka,

Correction: use beta.lr_mult = 0 instead of center=False. See the next answer for details.

Sure, that’s possible. You need to set center=False and scale=False on the BatchNorm layer if you want to zero-center the data and scale it to unit variance. Overall the effect will be similar to standard input normalisation, but subtly different. Instead of using the true normalisation statistics of the training data (calculated across the whole dataset), the BatchNorm global stats will be updated iteratively (often called the ‘running stats’). So there will certainly be differences at the start of training (since the running stats start from their default initial values rather than the true data statistics), and depending on your momentum parameter, the statistics may be skewed towards the most recent batches of data (so it might be a good idea to increase the momentum a little).

mx.gluon.nn.BatchNorm(center=False, scale=False, momentum=0.99)
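
For reference, the global stats are an exponential moving average of the per-batch statistics, so the momentum parameter controls how strongly they favour recent batches. A rough sketch of the update rule (illustrative only, not MXNet internals):

def update_running_stat(running, batch_stat, momentum=0.99):
    # higher momentum -> slower updates, i.e. less skew towards the most recent batches
    return momentum * running + (1 - momentum) * batch_stat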

That doesn’t really work unfortunately :frowning:

My mistake! Checking again, it turns out that setting center=False and scale=False on mx.gluon.nn.BatchNorm also disables the initial zero-centering and unit-variance scaling (the step before scaling by gamma and shifting by beta). You get all or nothing: both the initial zero-centering and the beta shifting, or neither.
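
For context, BatchNorm conceptually has two stages, and the aim here is to keep the first while freezing the second. A schematic NumPy sketch (per-batch stats only, ignoring the running stats):

import numpy as np

def batchnorm_schematic(x, gamma=1.0, beta=0.0, eps=1e-5):
    # stage 1: zero-center and scale to unit variance
    x_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)
    # stage 2: learnable scale and shift; with gamma=1 and beta=0 this is a no-op
    return gamma * x_hat + beta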

My first thought was to set grad_req to 'null' to avoid gradient calculation for beta and gamma, but this once again disables the initial scaling and shifting too. Given this, my recommendation would be to set the learning-rate multipliers for beta and gamma to 0 using lr_mult. The running stats are still calculated, but beta and gamma don’t change from 0 and 1 respectively (i.e. no beta shifting or gamma scaling).

from mxnet import gluon

class SimpleNet(gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super(SimpleNet, self).__init__(**kwargs)

        with self.name_scope():
            self.bn = gluon.nn.BatchNorm()
            # freeze beta at 0 and gamma at 1: the effective learning rate for
            # these parameters becomes 0, but the running stats are still updated
            self.bn.beta.lr_mult = 0
            self.bn.gamma.lr_mult = 0
            self.dense = gluon.nn.Dense(1)

    def hybrid_forward(self, F, x):
        x1 = self.bn(x)      # normalised input
        x2 = self.dense(x1)  # downstream layer operating on the normalised data
        return x1, x2
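
A rough usage sketch (hypothetical shapes and hyperparameters) to check that beta and gamma stay fixed while the running stats track the data:

import mxnet as mx
from mxnet import gluon, autograd, nd

net = SimpleNet()
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
loss_fn = gluon.loss.L2Loss()

x = nd.random.normal(5, 2, shape=(64, 10))  # deliberately not zero-mean / unit-variance
y = nd.random.normal(shape=(64, 1))

for _ in range(10):
    with autograd.record():
        x_norm, out = net(x)
        loss = loss_fn(out, y)
    loss.backward()
    trainer.step(x.shape[0])

print(net.bn.beta.data())          # still all zeros (lr_mult = 0)
print(net.bn.gamma.data())         # still all ones (lr_mult = 0)
print(net.bn.running_mean.data())  # drifting towards the data mean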