BatchNorm running_var value depends on the context (CPU vs. GPU)


In my model, the running_var parameter of some BatchNorm layers is NaN after the dummy forward pass I use to initialize the parameters. While debugging, I discovered that the value of running_var depends on the context I use. I assume this is a bug, since a model should behave the same no matter which context is used. Here is a minimal reproducible example:

import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.block import HybridBlock
import numpy as np

def _conv3x3(channels, stride, in_channels):
    """Build a bias-free 3x3 ``nn.Conv2D`` layer with padding 1.

    Args:
        channels: number of output channels.
        stride: convolution stride.
        in_channels: number of input channels (0 lets Gluon infer it from the
            first input).

    Returns:
        The configured ``nn.Conv2D`` layer.
    """
    conv_kwargs = dict(kernel_size=3, strides=stride, padding=1,
                       use_bias=False, in_channels=in_channels)
    return nn.Conv2D(channels, **conv_kwargs)

def get_dummy_data(ctx):
    """Return all-zero ``[data, label]`` arrays placed on *ctx*.

    ``data`` has shape (1, 3, 32, 32), a single CIFAR-sized image, and
    ``label`` has shape (1,). Both are built with ``np.zeros`` and then copied
    into ``mx.nd.array`` on the requested context.
    """
    batch = 1
    image_shape = (3, 32, 32)  # CIFAR-like C x H x W
    zero_data = np.zeros((batch,) + image_shape)
    zero_label = np.zeros((batch,))
    return [mx.nd.array(zero_data, ctx=ctx), mx.nd.array(zero_label, ctx=ctx)]

def check_net_params_for_nan(net, ctx):
    """Report parameters of *net* that contain NaN values.

    Also prints the values of every parameter whose name contains
    ``'running_var'`` (the BatchNorm running variances), so they can be
    compared across contexts.

    Args:
        net: block whose parameters are enumerated via ``collect_params()``.
        ctx: context whose copy of each parameter is read (``param.data(ctx)``).

    Returns:
        True if any parameter holds at least one NaN, otherwise False.
    """
    has_nan = False
    for name, param in net.collect_params().items():
        # Copy to host once and reuse it for both the NaN check and the print.
        values = param.data(ctx).asnumpy()
        if np.isnan(values).any():
            print('Param {} has nan values!'.format(name))
            has_nan = True
        if 'running_var' in name:
            print('Batchnorm running var values {}'.format(values))
    return has_nan

class test_net(HybridBlock):
    """Minimal three-convolution network used to reproduce the issue.

    Every layer lives in one ``HybridSequential`` body. The first 3x3
    convolution has ``in_channels=0`` (inferred from the first input). The
    next two are declared with 64 input channels and are added inside the
    body's name scope.

    NOTE(review): as pasted, this network contains no BatchNorm layer
    (``_conv3x3`` builds only a ``Conv2D``), so no ``running_var`` parameter
    exists. The original repro may have lost its BatchNorm lines; confirm.
    """

    def __init__(self, **kwargs):
        super(test_net, self).__init__(**kwargs)
        body = nn.HybridSequential(prefix='')
        self.body = body
        body.add(_conv3x3(64, 1, 0))
        with body.name_scope():
            for stride in (2, 1):
                body.add(_conv3x3(64, stride, 64))

    def hybrid_forward(self, F, x, *args, **kwargs):
        # Extra positional/keyword inputs are accepted but ignored.
        return self.body(x)

num_gpus = 2
ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
net = test_net()
net.initialize(mx.init.Xavier(), ctx=ctx)

# The dummy input is constant, so build it once instead of on every iteration.
data, label = get_dummy_data(ctx[0])

# dummy forward pass to initialize layers
# autograd.record() defaults to train_mode=True, so layers such as BatchNorm
# run in training mode here and may update their running statistics.
with mx.autograd.record():
    # to make sure all params are initialized. Needs to re-run for networks that only execute layers with a certain
    # probability to make sure all layers are initialized
    for i in range(100):
        output = net(data)

# NOTE(review): the reported GPU value 2.6561329e-05 matches 0.9 ** 100. That
# is what 100 training-mode updates with momentum 0.9 (Gluon BatchNorm's
# default) would give on all-zero input, where the batch variance is 0. So the
# GPU result may be the expected one, and the CPU result of 1 would mean the
# running variance was never updated by the forward passes. Confirm.
# Explicit raise instead of `assert`, so the check still runs under `python -O`.
if check_net_params_for_nan(net, ctx[0]):
    raise AssertionError('NaN values found in network parameters')

If I set num_gpus in this code to 0, the CPU is used and all running_var values are 1.

If I set num_gpus to 1 or 2, GPUs are used and all running_var values are 2.6561329e-05.

Can anyone reproduce this? I am using MXNet 1.3.1 on Ubuntu, built from source.