Custom block with parameters not used in prediction; cannot infer shapes

TL;DR

If you have a custom block with parameters that are not used in prediction, and you compose this block with a Dense layer, an exception is raised during shape inference.

Long Version

I’m writing a custom block that has some parameters that are not used in prediction, only for updates. I’d like to store them in the parameter dict so that they’re serialized with the block. They are passed to hybrid_forward(), but not used, so they aren’t part of the computation graph. When this block is used on its own, it works; MXNet just issues a warning about unused parameters in hybrid_forward(). However, when I combine the block with a Dense layer, I get an exception about not being able to infer the shape of the parameter, even though its shape is given.

The exception is thrown by HybridBlock._deferred_infer_shape(), but it’s actually caused by HybridBlock._infer_attrs() (lines 862-866). This function builds a dict, sdict, of outputs and aux states from the computation graph; it then iterates over the parameters in the parameter dict and sets an attribute (in this case, ‘shape’) using the value found in sdict. The problem is that if the parameter dict contains a key that does not appear in the computation graph, a KeyError is raised.
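In spirit, the failing loop looks something like this. This is a simplified sketch with made-up names to illustrate the failure mode, not the actual MXNet source:

```python
def infer_attrs(graph_shapes, params):
    """Sketch of the attribute-setting loop in HybridBlock._infer_attrs().

    graph_shapes: shapes inferred for every argument that appears in the
                  computation graph (the role played by sdict).
    params:       every name in the block's parameter dict.
    """
    inferred = {}
    for name in params:
        # KeyError here when a parameter (e.g. 'myblock1_alpha') was never
        # used in hybrid_forward(), so it has no entry in the graph.
        inferred[name] = graph_shapes[name]
    return inferred


graph_shapes = {"myblock1_theta": (4, 1)}
params = ["myblock1_theta", "myblock1_alpha"]
try:
    infer_attrs(graph_shapes, params)
except KeyError as e:
    print("KeyError:", e)  # KeyError: 'myblock1_alpha'
```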

Interestingly, the error is only raised when there is another layer in the network – presumably, one that needs to infer its shape.

One workaround, which I don’t like because it’s pretty hacky, is to add a no-op to hybrid_forward() that involves the unused parameter. I’ve illustrated this in the minimum reproducible example below, in the commented-out line.

Is there a better way to get around this? Or will it require changes to MXNet?
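If it does require a change to MXNet, the fix I’d imagine is simply a guard in that loop, so parameters absent from the computation graph keep their declared shape (and perhaps trigger a warning) instead of killing shape inference. Again, a simplified sketch with made-up names, not the actual source:

```python
import warnings


def infer_attrs_guarded(graph_shapes, params):
    """Set attributes only for parameters that appear in the graph."""
    inferred = {}
    for name in params:
        if name in graph_shapes:
            inferred[name] = graph_shapes[name]
        else:
            # Unused parameter: keep its declared shape rather than
            # raising a KeyError during deferred shape inference.
            warnings.warn("Parameter %s is not used by any computation; "
                          "keeping its declared shape." % name)
    return inferred
```

With this guard, the earlier failing case would return shapes for the used parameters and merely warn about the unused one.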

As always, thanks in advance for your time and help!

Minimum Reproducible Example

import mxnet as mx
from mxnet import gluon, nd


class MyBlock(gluon.HybridBlock):

    def __init__(self, in_units, **kwargs):
        super().__init__(**kwargs)

        with self.name_scope():
            # Parameter used in prediction
            self.theta = self.params.get("theta", shape=(in_units, 1))
            # Auxiliary parameter used only in (manual) updates
            self.alpha = self.params.get_constant("alpha", (0,))

    def hybrid_forward(self, F, x, theta, alpha):
        # This version throws an exception
        return F.dot(x, theta)
        # This hack makes it work, but seems silly
        # return F.broadcast_add(F.dot(x, theta), 0 * alpha)


# Works but raises warning
net = gluon.nn.HybridSequential()
net.add(MyBlock(in_units=3))
net.initialize(mx.init.Uniform(), ctx=mx.cpu())
net.hybridize()
X = nd.random.normal(shape=(5, 3))
print(net(X))

# Throws an exception
net = gluon.nn.HybridSequential()
net.add(gluon.nn.Dense(4))
net.add(MyBlock(in_units=4))
net.initialize(mx.init.Uniform(), ctx=mx.cpu())
net.hybridize()
X = nd.random.normal(shape=(5, 3))
print(net(X))

Output

/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py:548: UserWarning: Parameter myblock0_alpha is not used by any computation. Is this intended?
  out = self.forward(*args)
[[ 0.05662871]
 [-0.02758752]
 [ 0.0295723 ]
 [-0.01531329]
 [ 0.03770303]]
<NDArray 5x1 @cpu(0)>
/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py:548: UserWarning: Parameter myblock1_alpha is not used by any computation. Is this intended?
  out = self.forward(*args)
Traceback (most recent call last):
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 811, in _call_cached_op
    for is_arg, i in self._cached_op_args]
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 811, in <listcomp>
    for is_arg, i in self._cached_op_args]
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/parameter.py", line 543, in data
    return self._check_and_get(self._data, ctx)
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/parameter.py", line 234, in _check_and_get
    "num_features, etc., for network layers."%(self.name))
mxnet.gluon.parameter.DeferredInitializationError: Parameter 'dense0_weight' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 797, in _deferred_infer_shape
    self.infer_shape(*args)
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 870, in infer_shape
    self._infer_attrs('infer_shape', 'shape', *args)
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 866, in _infer_attrs
    setattr(i, attr, sdict[i.name])
KeyError: 'myblock1_alpha'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/blondon/Code/ts/mre.py", line 38, in <module>
    print(net(X))
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 548, in __call__
    out = self.forward(*args)
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 915, in forward
    return self._call_cached_op(x, *args)
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 813, in _call_cached_op
    self._deferred_infer_shape(*args)
  File "/Users/blondon/anaconda/envs/mxnet1.5/lib/python3.6/site-packages/mxnet/gluon/block.py", line 801, in _deferred_infer_shape
    raise ValueError(error_msg)
ValueError: Deferred initialization failed because shape cannot be inferred. 'myblock1_alpha'

Process finished with exit code 1