Hi! I am trying to implement a simple MLP on MNIST with gradient accumulation using the MXNet Module API. accum_step is the number of gradient accumulation steps. I am doing the following:
- Bind the model with grad_req=“add” parameter
- Run forward(batch) and backward() accum_step - 1 times.
- On the accum_step-th iteration I run model.update() and then multiply the model gradients by zero:
model._exec_group.grad_arrays *= 0
My problem is that the model is not training at all, i.e. the score is not changing.
(Without gradient accumulation steps and with grad_req=‘write’ the model trains perfectly.)
# First graph: two ReLU-activated fully connected layers (128, then 64
# hidden units) applied to the 'data' input. Its final node is `act2`.
data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data=data, num_hidden=128, name='fc1')
act1 = mx.sym.Activation(data=fc1, act_type="relu", name='relu1')
fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64, name='fc2')
act2 = mx.sym.Activation(data=fc2, act_type="relu", name='relu2')

# Second graph: a bias-free 10-unit fully connected layer followed by a
# softmax output. It reads its features from a separate 'data_2' variable
# rather than being chained onto `act2`, so the two graphs are independent
# symbols.
embedding = mx.sym.Variable('data_2')
fc_2 = mx.sym.FullyConnected(data=embedding, num_hidden=10, no_bias=True)
all_label = mx.sym.Variable('softmax_label')
softmax = mx.sym.SoftmaxOutput(data=fc_2, label=all_label, name='softmax')
# Run configuration.
n_epoch = 10      # number of passes over the training iterator
accum_step = 5    # mini-batches whose gradients are summed per update
accum = True      # toggle gradient accumulation on/off

# With accumulation the executors add into their gradient buffers and the
# mini-batch is 20 samples, so accum_step * batch_size = 100 samples per
# update -- the same as the plain 'write' configuration's batch of 100.
_grad_req = 'add' if accum else 'write'
batch_size = 20 if accum else 100
# basedir = os.path.dirname(__file__)
# get_data.get_mnist(os.path.join(basedir, "data"))
# Keyword arguments shared by the training and validation iterators:
# flattened 784-element samples, shuffled, with iterator logging enabled.
_mnist_iter_kwargs = dict(
    data_shape=(784,),
    batch_size=batch_size,
    shuffle=True,
    flat=True,
    silent=False,
)

# Training iterator; it is the only one given an explicit shuffle seed.
train_dataiter = mx.io.MNISTIter(
    image=os.path.join("mnist", "train-images-idx3-ubyte"),
    label=os.path.join("mnist", "train-labels-idx1-ubyte"),
    seed=10,
    **_mnist_iter_kwargs)

# Validation iterator over the t10k files (no seed is passed).
val_dataiter = mx.io.MNISTIter(
    image=os.path.join("mnist", "t10k-images-idx3-ubyte"),
    label=os.path.join("mnist", "t10k-labels-idx1-ubyte"),
    **_mnist_iter_kwargs)
# Shapes and names for the second module: it consumes the 64-wide output
# of the first module through the 'data_2' input.
model_2_input_shape = (batch_size, 64)
model_2_data_shapes = [('data_2', model_2_input_shape)]
model_2_data_names = ['data_2']
label_names = ['softmax_label']

# A single SGD instance is handed to both modules below.
# NOTE(review): rescale_grad is a fixed 0.25 in both the accumulating and
# the non-accumulating configuration -- confirm this is intended.
_rescale = 0.25
opt = optimizer.SGD(
    learning_rate=0.01,
    momentum=0.9,
    rescale_grad=_rescale,
)

# First module: wraps the graph ending at `act2`; it has no label input.
mod = mx.mod.Module(symbol=act2, label_names=None)
mod.bind(
    data_shapes=train_dataiter.provide_data,
    grad_req=_grad_req,
)
mod.init_params()
mod.init_optimizer(optimizer=opt)

# Second module: wraps the softmax graph. inputs_need_grad=True makes the
# gradient w.r.t. 'data_2' available via get_input_grads().
# NOTE(review): presumably a string grad_req is applied to the data input
# as well, so with 'add' that input gradient would accumulate across
# backward() calls -- verify against the MXNet Module documentation.
model_2 = mx.mod.Module(
    symbol=softmax,
    data_names=model_2_data_names,
    label_names=label_names,
)
model_2.bind(
    data_shapes=model_2_data_shapes,
    label_shapes=train_dataiter.provide_label,
    inputs_need_grad=True,
    grad_req=_grad_req,
)
model_2.init_params()
model_2.init_optimizer(optimizer=opt)

metric = mx.metric.create('acc')
grads_of_model_2_input_debug = mx.nd.zeros((batch_size, 64))
def _zero_grad_arrays(grad_groups):
    """Overwrite every gradient buffer in *grad_groups* with zeros.

    *grad_groups* is expected to be a list with one entry per argument,
    each entry being the list of per-device gradient NDArrays (the layout
    of ``_exec_group.grad_arrays`` / ``_exec_group.input_grad_arrays``).

    Assignment is used instead of ``*= 0`` so that NaN/Inf values are
    really cleared (NaN * 0 is still NaN), and every device buffer is
    reset rather than only the first one.
    """
    for per_device in grad_groups:
        for grad in per_device:
            if grad is not None:
                grad[:] = 0


# Train the two chained modules: `mod` produces features, `model_2`
# classifies them, and the gradient w.r.t. model_2's input is fed back into
# `mod` as its output gradient.
for i_epoch in range(n_epoch):
    for i_iter, batch in enumerate(train_dataiter):
        # Forward through both modules.
        mod.forward(batch)
        model_2_data = [mod.get_outputs()[0]]
        model_2_input = mx.io.DataBatch(model_2_data, label=batch.label)
        model_2.forward(model_2_input)
        model_2.update_metric(metric, batch.label)

        if accum:
            # model_2 is bound with grad_req='add' and inputs_need_grad=True,
            # so its input-gradient buffer is accumulated as well. Only the
            # parameter gradients should be summed across mini-batches; the
            # gradient handed to `mod` must belong to the current batch
            # alone, so clear the buffer before each backward pass.
            _zero_grad_arrays(model_2._exec_group.input_grad_arrays)

        # Backward through both modules.
        model_2.backward()
        grads_of_model_2_input = model_2.get_input_grads()[0]
        mod.backward(out_grads=[grads_of_model_2_input])

        if accum:
            # Update after every `accum_step` accumulated mini-batches
            # (batches 1..accum_step, accum_step+1..2*accum_step, ...).
            if (i_iter + 1) % accum_step == 0:
                model_2.update()
                mod.update()
                # Start the next accumulation window from zero.
                _zero_grad_arrays(model_2._exec_group.grad_arrays)
                _zero_grad_arrays(mod._exec_group.grad_arrays)
        else:
            model_2.update()
            mod.update()

    # Report and reset the training metric, then rewind the iterator.
    for name, val in metric.get_name_value():
        print('epoch %03d: %s=%f' % (i_epoch, name, val))
    metric.reset()
    train_dataiter.reset()