Sharing parameters between two modules through arg_dict


I’m trying to share parameters between two BucketingModules. Since binding a BucketingModule with a shared module is not supported, I have to pass the parameters around myself.

Calling get_params() on one module and set_params() on the other does the trick, but it’s slow because the parameters are copied during the set_params() call. I would like to avoid copying as much as possible, i.e. copy only when the contexts differ.

I’ve come up with the solution in the function update_params_async, but I believe it does not work correctly, since nothing enforces NDArray synchronization.

What is the best practice when sharing parameters between bucketing modules?
Would the update_params_async function work if I enforce synchronization with mx.nd.waitall()?
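A pure-Python toy model (not MXNet code) may help separate two questions here. As far as I can tell, MXNet's Python `Executor.arg_dict` is a plain dict built once by zipping the argument names with `arg_arrays`, while the bound executor keeps computing with the arrays it was bound to. If that is right, assigning `arg_dict[key] = value` only rebinds a Python reference, and `mx.nd.waitall()` would not change that. `ToyExecutor`, `share_by_rebinding` and `copy_in_place` are hypothetical names used only for illustration.

```python
class ToyExecutor:
    """Toy stand-in for an executor (not MXNet): forward() reads only the
    arrays it was bound to; arg_dict is a name -> array view built once."""

    def __init__(self, names, arrays):
        self.arg_arrays = arrays                  # the arrays "compute" uses
        self.arg_dict = dict(zip(names, arrays))  # view sharing the same list objects

    def forward(self, x):
        weight = self.arg_arrays[0]               # toy graph: out = x * weight
        return [xi * wi for xi, wi in zip(x, weight)]


def share_by_rebinding(src, dst):
    # mirrors update_params_async: only the dict entries are replaced
    for name, value in src.arg_dict.items():
        dst.arg_dict[name] = value


def copy_in_place(src, dst):
    # writes the values into dst's existing buffers instead
    for name, value in src.arg_dict.items():
        dst.arg_dict[name][:] = value


src = ToyExecutor(["weight"], [[2.0, 3.0]])
dst_rebound = ToyExecutor(["weight"], [[1.0, 1.0]])
dst_copied = ToyExecutor(["weight"], [[1.0, 1.0]])

share_by_rebinding(src, dst_rebound)
copy_in_place(src, dst_copied)

print(dst_rebound.arg_dict["weight"])   # [2.0, 3.0]: the dict looks updated...
print(dst_rebound.forward([1.0, 1.0]))  # [1.0, 1.0]: ...but compute uses the old array
print(dst_copied.forward([1.0, 1.0]))   # [2.0, 3.0]
```

In the toy, no amount of waiting makes `dst_rebound` use the new values; only the in-place write does.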

```python
import mxnet as mx

# define a computational graph with just one variable
w = mx.sym.Variable("weight")
x = mx.sym.Variable('data')
out = mx.sym.MakeLoss(x * w)

def make_module(out):
    # create a module, initialize its parameters and attach an optimizer
    # (get_params() needs initialized parameters, update() needs an optimizer)
    mod = mx.mod.Module(symbol=out, label_names=[], data_names=['data'])
    mod.bind(data_shapes=[('data', (2,))], label_shapes=None)
    mod.init_params()
    mod.init_optimizer()
    return mod

def update_params_sync(copy_from, copy_to):
    # call Module.get_params() and Module.set_params() to enforce synchronization of modules
    # This is slow, because the NDArray objects get copied to the context of the "copy_to" module
    arg_params, aux_params = copy_from.get_params()
    copy_to.set_params(arg_params, aux_params)

def update_params_async(copy_from, copy_to):
    # set parameter values by assigning the "copy_to" module the NDArrays from the "copy_from" module
    # I'm not 100% sure that this is synchronized.
    arg_params = copy_from._exec_group.execs[0].arg_dict
    aux_params = copy_from._exec_group.execs[0].aux_dict
    for key, value in arg_params.items():
        copy_to._exec_group.execs[0].arg_dict[key] = value
    for key, value in aux_params.items():
        copy_to._exec_group.execs[0].aux_dict[key] = value

# create two modules
# mod_a will be updated and mod_b will get the updated params from mod_a
mod_a = make_module(out)
mod_b = make_module(out)

while True:
    batch = mx.io.DataBatch([mx.nd.random.normal(0, 1, shape=(2,))], [])

    # one training step on mod_a only
    mod_a.forward(batch, is_train=True)
    mod_a.backward()
    mod_a.update()

    update_params_sync(mod_a, mod_b)
```
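A hedged alternative sketch, assuming the same Module internals that update_params_async already touches: instead of rebinding dict entries, write the values into the arrays copy_to is already bound to with `NDArray.copyto`, which accepts a destination NDArray. As I understand it, that copy is queued on MXNet's dependency engine, so later operations on mod_b should see the new values without an explicit `mx.nd.waitall()`. It still copies every parameter, but directly into existing buffers rather than (as far as I can tell) staging them through the CPU-side dicts that get_params()/set_params() use. `update_params_inplace` is a hypothetical name; `param_names`, `aux_names` and `_params_dirty` are private attributes that may differ between MXNet versions, and only the first executor is handled, as in the original.

```python
def update_params_inplace(copy_from, copy_to):
    """Hypothetical helper: copy parameter values from copy_from into the
    arrays copy_to is already bound to, instead of rebinding dict entries.

    Relies on private Module internals (_exec_group, execs, param_names,
    aux_names, _params_dirty) and, like update_params_async, only handles
    the first executor, i.e. a single device.
    """
    group_from = copy_from._exec_group
    src_exec = group_from.execs[0]
    dst_exec = copy_to._exec_group.execs[0]
    for name in group_from.param_names:
        # NDArray.copyto(other_ndarray) writes into the existing destination
        # array rather than replacing it
        src_exec.arg_dict[name].copyto(dst_exec.arg_dict[name])
    for name in group_from.aux_names:
        src_exec.aux_dict[name].copyto(dst_exec.aux_dict[name])
    # copy_to's cached CPU-side copies (returned by get_params) are now stale
    copy_to._params_dirty = True
```

Inside the loop, `update_params_inplace(mod_a, mod_b)` would take the place of `update_params_sync(mod_a, mod_b)`. Data inputs such as `data` are skipped because only parameter and auxiliary names are copied.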