If I use BatchNorm with global stats and fix the `gamma`, that should be a reasonable approximation of feature normalization, except for the `beta`. Is there any way to fix the `beta` of `BatchNorm` as well?
Hi @dmadeka,
Correction: use `beta.lr_mult=0` instead of `center=False`. See next answer for details.
Sure, that’s possible. You need to set `center=False` and `scale=False` on the `BatchNorm` layer if you want to zero-center the data and scale it to unit variance. Overall the effect will be similar to standard input normalisation, but subtly different. Instead of using the true normalisation statistics of the training data (calculated across the whole dataset), the `BatchNorm` global stats will be updated iteratively (often called the ‘running stats’). So there will certainly be differences at the start of training, since the running stats start from their initial values (by default zeros for the mean and ones for the variance) rather than from the dataset statistics. Depending on your `momentum` parameter, the statistics may also be skewed towards the most recent batches of data, so it might be a good idea to increase it a little.
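To make the `momentum` trade-off concrete, here is a minimal pure-Python sketch of how a running mean evolves. It assumes the update rule described in the MXNet `BatchNorm` operator documentation, `moving_mean = moving_mean * momentum + data_mean * (1 - momentum)`, and a running mean that starts at 0. The helper names `update_running_stat` and `simulate` are illustrative only.

```python
def update_running_stat(running, batch_stat, momentum):
    """One exponential-moving-average step, assuming the form
    running = momentum * running + (1 - momentum) * batch_stat."""
    return momentum * running + (1.0 - momentum) * batch_stat


def simulate(batch_means, momentum, init=0.0):
    """Feed a sequence of per-batch means through the EMA and return the history."""
    running = init
    history = []
    for m in batch_means:
        running = update_running_stat(running, m, momentum)
        history.append(running)
    return history


# Hypothetical data whose true mean is 5.0, while the running mean starts at 0.0.
batches = [5.0] * 200

for momentum in (0.9, 0.99):
    hist = simulate(batches, momentum)
    # The most recent batch alone carries a weight of (1 - momentum).
    print(f"momentum={momentum}: after 1 batch {hist[0]:.3f}, "
          f"after 50 batches {hist[49]:.3f}, after 200 batches {hist[199]:.3f}")
```

A higher `momentum` gives each new batch less weight (1% at 0.99 versus 10% at 0.9), so the estimate is less skewed towards recent batches, but it also takes longer to move away from its initial value at the start of training.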
```python
import mxnet as mx

mx.gluon.nn.BatchNorm(center=False, scale=False, momentum=0.99)
```
That doesn’t really work, unfortunately.
My mistake! Checking again, it turns out that setting `center=False` and `scale=False` on `mx.gluon.nn.BatchNorm` also disables the initial zero-centering and unit-variance scaling (the step before scaling by `gamma` and shifting by `beta`). You get all or nothing: either both the initial zero-centering and the `beta` shift, or neither.
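For reference, the standard per-channel BatchNorm transform separates the two steps discussed here: normalisation with a mean $\mu$ and variance $\sigma^2$ (the running stats when global stats are used), followed by the learned affine step with $\gamma$ (`gamma`) and $\beta$ (`beta`):

```latex
\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
y = \gamma \, \hat{x} + \beta
```

With $\gamma = 1$ and $\beta = 0$ the second step is the identity, $y = \hat{x}$, which is plain feature standardisation.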
My first thought was to set `.grad_req` to `'null'` to avoid computing gradients for `beta` and `gamma`, but this once again disables the initial zero-centering and scaling too. Given this, my recommendation would be to set the learning rate multipliers for `beta` and `gamma` to 0 using `lr_mult`. The running stats are still calculated, but `beta` and `gamma` never change from their initial values of 0 and 1 respectively (i.e. no `beta` shifting or `gamma` scaling).
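```python
from mxnet import gluon


class SimpleNet(gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super(SimpleNet, self).__init__(**kwargs)
        with self.name_scope():
            self.bn = gluon.nn.BatchNorm()
            # Freeze beta (initialised to 0) and gamma (initialised to 1):
            # their effective learning rate becomes 0, while the running
            # mean/variance are still updated during training.
            self.bn.beta.lr_mult = 0
            self.bn.gamma.lr_mult = 0
            self.dense = gluon.nn.Dense(1)

    def hybrid_forward(self, F, x):
        x1 = self.bn(x)      # normalised features
        x2 = self.dense(x1)  # model output
        return x1, x2
```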
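To see why `lr_mult = 0` leaves you with plain standardisation, here is a toy, pure-Python model of a single-feature BatchNorm layer. It is not MXNet's implementation: the class `FrozenAffineBatchNorm1D` is hypothetical, it assumes plain SGD where each parameter moves by `lr * lr_mult * grad` (no optimizer momentum or weight decay), an exponential moving average with `momentum` for the running stats, and default initial values of 1/0 for `gamma`/`beta` and 0/1 for the running mean/variance.

```python
import math
import random


class FrozenAffineBatchNorm1D:
    """Hypothetical single-feature BatchNorm: running stats update every batch,
    while gamma/beta only change through SGD steps scaled by lr_mult."""

    def __init__(self, momentum=0.9, eps=1e-5, lr_mult=0.0):
        self.momentum = momentum
        self.eps = eps
        self.lr_mult = lr_mult
        self.gamma, self.beta = 1.0, 0.0                  # assumed initial values
        self.running_mean, self.running_var = 0.0, 1.0

    def update_stats(self, batch):
        # Exponential moving average of the batch mean and (biased) variance.
        mean = sum(batch) / len(batch)
        var = sum((v - mean) ** 2 for v in batch) / len(batch)
        m = self.momentum
        self.running_mean = m * self.running_mean + (1 - m) * mean
        self.running_var = m * self.running_var + (1 - m) * var

    def forward(self, x):
        # Normalise with the running (global) stats, then scale and shift.
        x_hat = (x - self.running_mean) / math.sqrt(self.running_var + self.eps)
        return self.gamma * x_hat + self.beta

    def sgd_step(self, grad_gamma, grad_beta, lr):
        # Effective learning rate of each parameter is lr * lr_mult.
        self.gamma -= lr * self.lr_mult * grad_gamma
        self.beta -= lr * self.lr_mult * grad_beta


random.seed(0)
bn = FrozenAffineBatchNorm1D(momentum=0.9, lr_mult=0.0)
for _ in range(500):
    batch = [random.gauss(3.0, 2.0) for _ in range(64)]  # synthetic data
    bn.update_stats(batch)
    # Arbitrary non-zero gradients: with lr_mult=0 they have no effect.
    bn.sgd_step(grad_gamma=0.5, grad_beta=-0.3, lr=0.1)

print("gamma, beta:", bn.gamma, bn.beta)
print("running mean/var:", bn.running_mean, bn.running_var)
```

The running stats drift towards the data's mean (about 3.0) and variance (about 4.0) while `gamma` and `beta` stay at 1 and 0, so `forward` returns standardised values. With a non-zero `lr_mult`, the same `sgd_step` call would move `gamma` and `beta` and the output would no longer be pure standardisation.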