Gradient Accumulation in Module

Hi! I am trying to implement a simple MLP on MNIST with gradient accumulation using the MXNet Module API. accum_step is the number of gradient accumulation steps. I am doing the following:

  1. Bind the model with grad_req="add".
  2. Run forward(batch) and backward() accum_step - 1 times, letting the gradients accumulate in place.
  3. On the accum_step-th iteration, run model.update() and then reset the gradients by multiplying every gradient array (for each parameter i and device j) by zero, as in the sketch below:
    model._exec_group.grad_arrays[i][j] *= 0
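
For reference, the pattern I am after for a single Module on one context looks like this (just a sketch; train_iter, accum_step and the network are placeholders):

import mxnet as mx

net = mx.sym.FullyConnected(mx.sym.Variable('data'), num_hidden=10)
net = mx.sym.SoftmaxOutput(net, name='softmax')

m = mx.mod.Module(net)
m.bind(data_shapes=[('data', (20, 784))],
       label_shapes=[('softmax_label', (20,))],
       grad_req='add')                    # step 1: gradients are summed, not overwritten
m.init_params()
m.init_optimizer(optimizer='sgd')

for i, batch in enumerate(train_iter):
    m.forward(batch)
    m.backward()                          # step 2: gradients accumulate in place
    if (i + 1) % accum_step == 0:         # step 3: one update per accum_step batches
        m.update()
        for grad_list in m._exec_group.grad_arrays:
            for grad in grad_list:
                grad[:] = 0               # reset before the next accumulation window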

My problem is that the model is not training at all, i.e. the score is not changing.
(Without gradient accumulation and with grad_req='write', the model trains perfectly.)

import os

import mxnet as mx
from mxnet import optimizer

data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
# fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)

embedding = mx.symbol.Variable('data_2')
fc_2 = mx.sym.FullyConnected(embedding, num_hidden=10, no_bias=True)

all_label = mx.symbol.Variable('softmax_label')
softmax = mx.symbol.SoftmaxOutput(data=fc_2, label=all_label, name='softmax')

n_epoch = 10

accum_step = 5
accum = True
if accum:
    _grad_req = 'add'
    batch_size = 20
else:
    _grad_req = 'write'
    batch_size = 100
# basedir = os.path.dirname(__file__)
# get_data.get_mnist(os.path.join(basedir, "data"))

train_dataiter = mx.io.MNISTIter(
    image=os.path.join("mnist", "train-images-idx3-ubyte"),
    label=os.path.join("mnist", "train-labels-idx1-ubyte"),
    data_shape=(784,),
    batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
val_dataiter = mx.io.MNISTIter(
    image=os.path.join("mnist", "t10k-images-idx3-ubyte"),
    label=os.path.join("mnist", "t10k-labels-idx1-ubyte"),
    data_shape=(784,),
    batch_size=batch_size, shuffle=True, flat=True, silent=False)

_rescale = 0.25
opt = optimizer.SGD(learning_rate=0.01, momentum=0.9, rescale_grad=_rescale)
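
(Side question: with accumulation, each update sees a gradient summed over batch_size * accum_step = 20 * 5 = 100 examples, so should rescale_grad be 1/100 = 0.01 here rather than 0.25? I use the same value in both modes.)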

# first module: data -> 64-d embedding (no loss head, hence label_names=None)
mod = mx.mod.Module(act2, label_names=None)

mod.bind(data_shapes=train_dataiter.provide_data, grad_req=_grad_req)
mod.init_params()
mod.init_optimizer(optimizer=opt)

model_2_input_shape = (batch_size, 64)
model_2_data_shapes = [('data_2', model_2_input_shape)]
model_2_data_names = ['data_2']
label_names = ['softmax_label']

# second module: embedding -> softmax head; inputs_need_grad=True so we can
# pull the gradient w.r.t. its input and pass it back to the first module
model_2 = mx.mod.Module(symbol=softmax, data_names=model_2_data_names, label_names=label_names)
model_2.bind(data_shapes=model_2_data_shapes, label_shapes=train_dataiter.provide_label, inputs_need_grad=True, grad_req=_grad_req)
model_2.init_params()
model_2.init_optimizer(optimizer=opt)

metric = mx.metric.create('acc')

grads_of_model_2_input_debug = mx.nd.zeros((batch_size, 64))  # unused; left over from debugging

for i_epoch in range(n_epoch):
    for i_iter, batch in enumerate(train_dataiter):

        mod.forward(batch)
        # feed the first module's output into the second module as its data
        model_2_data = [mod.get_outputs()[0]]
        model_2_input = mx.io.DataBatch(model_2_data, label=batch.label)
        model_2.forward(model_2_input)
        model_2.update_metric(metric, batch.label)
        model_2.backward()

        # propagate the gradient w.r.t. model_2's input back through the first module
        grads_of_model_2_input = model_2.get_input_grads()[0]
        mod.backward(out_grads=[grads_of_model_2_input])
        if accum:
            # update once every accum_step batches
            if (i_iter + 1) % accum_step == 0:
                model_2.update()
                mod.update()

                # reset accumulated gradients (all parameters, all devices)
                for grad_list in model_2._exec_group.grad_arrays:
                    for grad in grad_list:
                        grad[:] = 0

                for grad_list in mod._exec_group.grad_arrays:
                    for grad in grad_list:
                        grad[:] = 0

        else:
            model_2.update()
            mod.update()

    for name, val in metric.get_name_value():
        print('epoch %03d: %s=%f' % (i_epoch, name, val))

    metric.reset()
    train_dataiter.reset()
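
To check whether the gradients actually accumulate, I would print a gradient norm after each backward() call (sketch, assuming a single context; it should grow within an accumulation window and drop to zero after the reset):

g = mod._exec_group.grad_arrays[0][0]   # first parameter, first context
print(i_iter, mx.nd.norm(g).asscalar())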

Hi oljike,
Did you solve this problem?

Hi, I am encountering the same problem. Have you solved it?