Updating a Constant in a distributed kvstore

I have a value associated with a Block that needs to be manually set once per training batch. When training with a local or device KVStore, I found the following works fine:

class MyBlock(Block):
    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.my_cached_value_ = self.params.get_constant(
                'my_cached_value_',
                # Some initial value (and a made-up shape for this example):
                nd.zeros((5, 13)))

    def forward(self, x):
        # A new NDArray of shape (5, 13):
        new_cached_value = some_function(x, self.my_cached_value_.data())
        with autograd.pause():
            # This set_data call is the part that breaks in distributed mode:
            self.my_cached_value_.set_data(new_cached_value)
        return new_cached_value
Unfortunately, when I attempt to scale the training job horizontally (using a gluon.Trainer to control weight updates), this approach fails: set_data tries to re-set the value in the distributed KVStore on all contexts, which raises an exception.

This leads to my question: Is there a recommended way to manually set a Gluon constant during distributed training?

I’ve been exploring different hacks to get around this problem, but so far they all come with their own complications. Answers like this one from @ThomasDelteil make me think I’m missing something basic here.

Thanks in advance!

Additional Notes/Attempts at a Solution

  • This is for an implementation of spectral normalization that supports distributed training. (The MXNet-provided SNGAN example only works on a single instance AFAICT.)
  • Using kvstore.push and kvstore.pull manually during training is problematic without rolling my own updater (that overrides the cached value rather than adding to it).
  • self.my_cached_value_.initialize(…, force_reinit=True) feels like an abuse of that method, but it does work on a single machine. I have yet to confirm that it works on a cluster.
  • Bypassing the KVStore by tracking a NumPy array on each context/machine would potentially work for my current use case, but I need to be sure that my implementation doesn’t mess up backprop.
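To illustrate the kvstore.push/kvstore.pull bullet above: the updater I would need has assignment rather than accumulation semantics. A minimal sketch of the idea (with NumPy arrays standing in for NDArrays, and overwrite_updater a made-up name; actually registering such an updater on a distributed KVStore is exactly the part I haven’t solved):

```python
import numpy as np

def overwrite_updater(key, new_val, stored):
    # A KVStore's default update for push/pull accumulates
    # (roughly stored += new_val, possibly rescaled). For a cached
    # constant I instead want plain assignment:
    stored[:] = new_val

# Toy demonstration of the semantics:
stored = np.zeros((5, 13))
overwrite_updater('my_cached_value_', np.ones((5, 13)), stored)
# stored now holds the pushed value, not the sum of pushes.
```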

Update: For the time being, I’ve pulled the cached value out of the kvstore entirely and am caching NumPy arrays on each machine separately. I’ll post an update here if this works. Would love to see alternative approaches, though.
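For reference, the bypass I’m trying looks roughly like this (a sketch; LocalCache is a made-up helper, and the real version still has to guarantee that nothing here interferes with autograd):

```python
import numpy as np

class LocalCache:
    """Per-worker cache held outside the KVStore.

    Each machine keeps its own copy of the value and overwrites it once
    per batch. Since it is a plain NumPy array, it cannot participate in
    autograd at all, which is both the point and the risk.
    """
    def __init__(self, shape=(5, 13)):
        self.value = np.zeros(shape)

    def update(self, new_value):
        # Overwrite, don't accumulate:
        self.value = np.array(new_value, copy=True)

cache = LocalCache()
cache.update(np.full((5, 13), 2.0))
```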