I’m training a ResNet-50 model for image classification. I wanted to track the norms of my parameters, so I wrote a custom callback to compute them, plus a custom mx.mod.Module with the fit() function overridden. The only thing I changed in fit() was adding a call to get_params() to pull the parameters back from the GPUs after every update.
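The callback itself is nothing fancy; roughly, it reads the arg_params that my overridden fit() exposes through locals (see the loop below) and prints each parameter's L2 norm. A simplified sketch:

import mxnet as mx

def norm_callback(batch_end_params):
    # arg_params ends up in locals because my fit() calls get_params()
    # right before constructing BatchEndParam (see the loop below)
    arg_params = batch_end_params.locals['arg_params']
    for name, array in arg_params.items():
        # mx.nd.norm computes the L2 norm of the flattened array
        print('%s: %f' % (name, mx.nd.norm(array).asscalar()))

My training loop in that function looks like this: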
while not end_of_batch:
    data_batch = next_data_batch
    if monitor is not None:
        monitor.tic()
    self.forward_backward(data_batch)
    self.update()
    try:
        # pre-fetch the next batch
        next_data_batch = next(data_iter)
        self.prepare(next_data_batch)
    except StopIteration:
        end_of_batch = True
    self.update_metric(eval_metric, data_batch.label)
    if monitor is not None:
        monitor.toc_print()
    arg_params, aux_params = self.get_params()  # the only line I added
    if batch_end_callback is not None:
        batch_end_params = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                  eval_metric=eval_metric,
                                                  locals=locals())
        for callback in _as_list(batch_end_callback):
            callback(batch_end_params)
    nbatch += 1
This works, but I noticed something odd: adding the call to get_params() actually increased the training speed. With the standard mx.mod.Module I got an average of 750 samples/sec, but with that one extra line I get an average of 1050 samples/sec.
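In case the measurement matters: the samples/sec numbers come from a Speedometer-style batch-end callback. I’m assuming the stock mx.callback.Speedometer below; mod, train_iter, and the batch size are placeholders for my actual setup.

import mxnet as mx

# placeholder values: batch size 128, print speed every 50 batches
speed_cb = mx.callback.Speedometer(batch_size=128, frequent=50)

# both callbacks fire at the end of every batch
mod.fit(train_iter, num_epoch=90,
        batch_end_callback=[speed_cb, norm_callback])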
Any idea why this would speed things up?