switch_bucket() sets grad_req to 'write' by default and overwrites other settings


#1

I want to set grad_req to 'add' for all my parameters by specifying it in the bind() call on my top-level module.
The problem is that I use a BucketingModule. When I call forward(), it automatically calls switch_bucket(), which calls bind() again without specifying grad_req, so grad_req falls back to its default of 'write'.
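To make the mechanism concrete, here is a simplified model in plain Python. It does not use MXNet: `ToyModule` and `ToyBucketingModule` are hypothetical stand-ins, not MXNet's actual implementation. They only mimic the behaviour described above, where the grad_req passed to the top-level bind() is not remembered, so a bucket bound later by switch_bucket() gets the default 'write'.

```python
class ToyModule:
    """Hypothetical stand-in for a per-bucket module; not MXNet code."""

    def __init__(self, bucket_key):
        self.bucket_key = bucket_key
        self.grad_req = None

    def bind(self, grad_req='write'):
        # As with MXNet's bind(), grad_req defaults to 'write' when omitted
        self.grad_req = grad_req


class ToyBucketingModule:
    """Hypothetical stand-in for a bucketing wrapper; not MXNet code."""

    def __init__(self, default_bucket_key):
        self.default_bucket_key = default_bucket_key
        self.buckets = {}

    def bind(self, grad_req='write'):
        # Binds the default bucket with the caller's grad_req,
        # but does not store grad_req for buckets created later
        module = ToyModule(self.default_bucket_key)
        module.bind(grad_req=grad_req)
        self.buckets[self.default_bucket_key] = module

    def switch_bucket(self, bucket_key):
        # A new bucket is bound without passing grad_req, so it gets 'write'
        if bucket_key not in self.buckets:
            module = ToyModule(bucket_key)
            module.bind()
            self.buckets[bucket_key] = module
        return self.buckets[bucket_key]

    def forward(self, bucket_key):
        # forward() switches to the batch's bucket before computing anything
        return self.switch_bucket(bucket_key)


mod = ToyBucketingModule(default_bucket_key=10)
mod.bind(grad_req='add')
print(mod.buckets[10].grad_req)  # add
print(mod.forward(20).grad_req)  # write
```

The default bucket keeps 'add', but any bucket first seen during forward() ends up with 'write'.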

Is this intended behavior? Can I somehow work around this while still using bucketing?

I already opened an issue on GitHub here: https://github.com/apache/incubator-mxnet/issues/10904


#2

The simplest solution is probably to inherit from BucketingModule and override bind() so that grad_req defaults to 'add'.

Something like:

```python
import mxnet as mx

class MyBucketingModule(mx.mod.BucketingModule):
    # Same as BucketingModule.bind(), except that grad_req defaults to 'add'
    def bind(self, data_shapes, label_shapes=None, for_training=True, inputs_need_grad=False,
             force_rebind=False, shared_module=None, grad_req='add'):
        super(MyBucketingModule, self).bind(data_shapes, label_shapes, for_training,
                                            inputs_need_grad, force_rebind, shared_module, grad_req)
```