Problem when modifying the d2l.ai implementation of ResNet

Hello, I am seeing strange behaviour when modifying the ResNet implementation from the d2l.ai online book along the lines suggested in the exercises. I am running this on a p2.xlarge instance, so on GPU. I hope this is not related to this batch normalization issue:
https://github.com/apache/incubator-mxnet/issues/14357

Here is the code, mostly from the d2l.ai notebook, but with small changes:

import sys
sys.path.insert(0, '..')

import d2l
from mxnet import gluon, init, nd
from mxnet.gluon import nn

# This class has been saved in the d2l package for future use
class Residual(nn.Block):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def forward(self, X):
        # MY CHANGE: pre-activation order (BN, relu, conv) instead of the original (conv, BN, relu)
        Y = self.conv1(nd.relu(self.bn1(X)))
        Y = self.conv2(nd.relu(self.bn2(Y)))
        if self.conv3:
            X = self.conv3(X)
        return Y + X
        # ORIGINAL CODE:
        #Y = nd.relu(self.bn1(self.conv1(X)))
        #Y = self.bn2(self.conv2(Y))
        #if self.conv3:
        #    X = self.conv3(X)
        #return nd.relu(Y + X)

net = nn.Sequential()
net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
        nn.BatchNorm(), nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2, padding=1))

def resnet_block(num_channels, num_residuals, first_block=False):
    blk = nn.Sequential()
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.add(Residual(num_channels, use_1x1conv=True, strides=2))
        else:
            blk.add(Residual(num_channels))
    return blk

net.add(resnet_block(64, 2, first_block=True),
        resnet_block(128, 2),
        resnet_block(256, 2),
        resnet_block(512, 2))
net.add(nn.GlobalAvgPool2D(), nn.Dense(10))

lr, num_epochs, batch_size, ctx = 0.05, 5, 256, d2l.try_gpu()
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx,
              num_epochs)

Running this, I get:

training on gpu(0)
epoch 1, loss nan, train acc 0.099, test acc 0.100, time 83.9 sec
epoch 2, loss nan, train acc 0.100, test acc 0.100, time 76.1 sec
epoch 3, loss nan, train acc 0.100, test acc 0.100, time 76.4 sec
epoch 4, loss nan, train acc 0.100, test acc 0.100, time 76.5 sec
epoch 5, loss nan, train acc 0.100, test acc 0.100, time 76.5 sec

With the original forward order, training works fine (same as in the d2l notebook).

NaN loss is usually a sign of exploding gradients. Try reducing your learning rate; with your code and a learning rate of 0.001 I got the following training log:

training on gpu(0)
epoch 1, loss 1.0534, train acc 0.688, test acc 0.780, time 15.2 sec
epoch 2, loss 0.6392, train acc 0.799, test acc 0.811, time 13.9 sec
epoch 3, loss 0.5438, train acc 0.822, test acc 0.829, time 13.9 sec
epoch 4, loss 0.4913, train acc 0.837, test acc 0.842, time 13.9 sec
epoch 5, loss 0.4563, train acc 0.846, test acc 0.848, time 13.9 sec
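
For reference, the change amounts to just the learning rate in the trainer setup (a minimal sketch reusing the variables defined above; the commented-out clip_gradient line is another common mitigation for exploding gradients that I have not tested here):

lr = 0.001  # reduced from 0.05
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
# Alternatively, element-wise gradient clipping can also help against exploding gradients, e.g.
# trainer = gluon.Trainer(net.collect_params(), 'sgd',
#                         {'learning_rate': lr, 'clip_gradient': 1.0})
d2l.train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx,
              num_epochs)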