Importing an exported float16 net fails after cast

I have no trouble casting a network to float16 and serializing it, but when I try to deserialize, I get an error:

>>> import mxnet as mx
>>> from gluoncv import model_zoo
>>> from mxnet import gluon
>>> import numpy as np
>>> 
>>> mx.__version__
'1.3.0'
>>> 
>>> ctx = [mx.gpu()]
>>> model_name = 'cifar_resnet20_v1'
>>> file_prefix = '/tmp/' + model_name
>>> params_file = file_prefix + '-0000.params'
>>> symbol_file = file_prefix + '-symbol.json'
>>> 
>>> net = model_zoo.get_model(model_name, pretrained=True, ctx=ctx)
>>> net.cast('float16')
>>> 
>>> net.hybridize()
>>> data = mx.nd.ones((1024, 3, 32, 32), ctx=ctx[0])
>>> out = net(data.astype('float16', copy=False))
>>> 
>>> 
>>> 
>>> net.export(file_prefix)
[19:48:15] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:109: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
>>> 
>>> deserialized_net = mx.gluon.nn.SymbolBlock.imports(symbol_file, ['data'], params_file)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/block.py", line 1025, in imports
    ret.collect_params().load(param_file, ctx=ctx)
  File "/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/parameter.py", line 918, in load
    self[name]._load_init(arg_dict[name], ctx)
  File "/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/parameter.py", line 243, in _load_init
    self.name, str(self.dtype), str(data.dtype))
AssertionError: Failed loading Parameter 'cifarresnetv10_stage3_conv5_weight' from saved params: dtype incompatible expected <type 'numpy.float32'> vs saved <type 'numpy.float16'>

It seems that loading both the model (symbol) and the parameters as float16 from files is not supported yet; only loading the parameters on their own as float16 is supported.
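
For completeness, here is roughly what I mean by loading the parameters only (a minimal, untested sketch; it assumes a fresh Python session so that the auto-generated parameter prefix, e.g. cifarresnetv10_, matches the names in the exported .params file):

import mxnet as mx
from gluoncv import model_zoo

ctx = mx.gpu()
# Rebuild the architecture from its Python definition instead of the symbol JSON.
net = model_zoo.get_model('cifar_resnet20_v1', pretrained=False, ctx=ctx)
net.cast('float16')  # conv/dense weights become float16, BatchNorm stays float32

# Loading only the parameters works, because every parameter of the in-memory
# network already has the dtype that export() wrote to the file.
net.collect_params().load('/tmp/cifar_resnet20_v1-0000.params', ctx=ctx)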

I created a custom function for your example by copying and slightly modifying the code of gluon.nn.SymbolBlock.imports(...). For some reason the BatchNorm layers ignore the cast to float16 and keep their parameters in float32, so I have to cast those parameters back to float32 before loading. The code looks a little odd because of that, but it works.

import mxnet as mx
from gluoncv import model_zoo
from mxnet.gluon import SymbolBlock
from mxnet.symbol import symbol


def load_float16_model(symbol_file, input_names, param_file=None, ctx=None):
    sym = symbol.load(symbol_file)

    if isinstance(input_names, str):
        input_names = [input_names]

    inputs = [symbol.var(i) for i in input_names]
    ret = SymbolBlock(sym, inputs)
    ret.cast('float16') # explicit cast to float16

    if param_file is not None:
        for param_name, param in ret.collect_params().items():
            if "batchnorm" in param_name: # cast batchnorm layers back to float32
                param.cast('float32')

        ret.collect_params().load(param_file, ctx=ctx)
    return ret


ctx = [mx.gpu()]
model_name = 'cifar_resnet20_v1'

net = model_zoo.get_model(model_name, pretrained=True, ctx=ctx)
net.cast('float16')

net.hybridize()

data = mx.nd.ones((1024, 3, 32, 32), ctx=ctx[0])
out = net(data.astype('float16', copy=False))

file_prefix = '/tmp/' + model_name
net.export(file_prefix)

params_file = file_prefix + '-0000.params'
symbol_file = file_prefix + '-symbol.json'

deserialized_net = load_float16_model(symbol_file, ['data'], params_file, ctx=ctx[0])
output = deserialized_net(data.astype('float16', copy=False))
print(output)

I tested it with mxnet version 1.3.1 and gluoncv version 0.2.0. Running this code prints:

[[ 5.74  -1.141  3.027 ... -4.594  1.127 -4.164]
 [ 5.74  -1.141  3.027 ... -4.594  1.127 -4.164]
 [ 5.74  -1.141  3.027 ... -4.594  1.127 -4.164]
 ...
 [ 5.74  -1.141  3.027 ... -4.594  1.127 -4.164]
 [ 5.74  -1.141  3.027 ... -4.594  1.127 -4.164]
 [ 5.74  -1.141  3.027 ... -4.594  1.127 -4.164]]
<NDArray 1024x10 @gpu(0)>
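
As an extra sanity check (not part of the original snippet), the deserialized net's output can be compared to the output of the original casted net:

import numpy as np

# 'out' comes from the original casted net, 'output' from the deserialized one;
# for the same input the two float16 results should match.
print(np.allclose(out.asnumpy(), output.asnumpy()))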

I have created a feature request to support this behaviour out of the box: https://github.com/apache/incubator-mxnet/issues/13147.