Batch norm crashes with float16


I am trying to run SSD from gluoncv with float16. While I understand that the model does not fully support float16, it seems the ResNet backbone already crashes because of its BatchNorm layer. See my minimal example below:

import mxnet as mx
import gluoncv as gcv
from gluoncv.data.transforms import presets
from matplotlib import pyplot as plt

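ctx = [mx.gpu(0)]
# Download the test image. The URL prefix was cut off in the original post;
# mx.test_utils.download is an assumed choice of download helper.
mx.test_utils.download("..." +
        "20012568/cbc2d6f6-a27d-11e6-94c3-d35a9cb47609.jpg", 'street.jpg')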
image_list = ['street.jpg']

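# Load the pretrained weights on the same GPU as the input
net = gcv.model_zoo.get_model('ssd_512_resnet50_v1_voc', pretrained=True, ctx=ctx)
net.set_nms(0.45, 200)
# Casting the whole network to float16 is what triggers the error below
net.cast('float16')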

ax = None
for image in image_list:
    x, img = presets.ssd.load_test(image, short=512)
    x = x.as_in_context(mx.gpu(0))
    x = x.astype('float16', copy=False)
    ids, scores, bboxes = [xx[0].asnumpy() for xx in net(x)]
    ax = gcv.utils.viz.plot_bbox(img, bboxes, scores, ids,
                                class_names=net.classes, ax=ax)

The error I get is the following:

mxnet.base.MXNetError: Error in operator ssd0_resnetv10_batchnorm0_fwd: [00:11:51] src/operator/nn/ Check failed: (*in_type)[i] == dtype_param (2 vs. 0) This layer requires uniform type. Expected 'float32' v.s. given 'float16' at 'gamma'


Since you have solved it, I am posting the solution here for reference.

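The SSD network uses a SymbolBlock as its inner backbone, and SymbolBlock does not recognize BatchNorm layers when the network is cast. As a result, net.cast('float16') also casts the BatchNorm parameters to float16, while the BatchNorm operator expects them to stay float32 for float16 input, hence the "Expected 'float32' v.s. given 'float16' at 'gamma'" error.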
One quick solution: instead of calling net.cast('float16'), iterate through net.collect_params().items() and cast a parameter to float16 only if its name does not end with gamma, beta, moving_mean or moving_var.
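A minimal sketch of that workaround, assuming MXNet 1.x Gluon, where a parameter dictionary supports .items() and each Parameter has a cast(dtype) method. The function name cast_except_batchnorm and the BN_SUFFIXES tuple are hypothetical, not part of MXNet or GluonCV. Gluon's own nn.BatchNorm names its statistics running_mean and running_var, while the answer above mentions moving_mean and moving_var, so this sketch skips both spellings to be safe. It only relies on duck typing, so it does not import mxnet itself.

```python
# Hypothetical list of parameter-name suffixes treated as BatchNorm parameters.
# Both the Gluon (running_*) and symbol-API (moving_*) spellings are included,
# since which one a given SymbolBlock uses is an assumption here.
BN_SUFFIXES = ('gamma', 'beta', 'running_mean', 'running_var',
               'moving_mean', 'moving_var')


def cast_except_batchnorm(net, dtype='float16', keep_suffixes=BN_SUFFIXES):
    """Cast every parameter of `net` to `dtype`, except BatchNorm parameters.

    Hypothetical replacement for net.cast(dtype) on networks that contain a
    SymbolBlock, whose BatchNorm parameters must stay float32 for float16 data.
    Returns the names of the parameters that were left untouched.
    """
    kept = []
    for name, param in net.collect_params().items():
        # str.endswith accepts a tuple and matches any of its suffixes
        if name.endswith(keep_suffixes):
            kept.append(name)
            continue
        param.cast(dtype)
    return kept
```

In the question's example, call cast_except_batchnorm(net) in place of net.cast('float16'), right after loading the model and before the first forward pass; the input still needs x.astype('float16'). The returned list of skipped names is a quick way to confirm that the BatchNorm parameters were actually matched by the suffix filter.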