I’m training a ResNet-50 model for image classification. I wanted to track the norm of my parameters, so I wrote a custom callback to compute them and a custom subclass of mx.mod.Module with the fit() function overridden. The only thing I changed in fit() was to add a call to get_params() to pull the parameters from the GPUs after every update. My training loop in that function looks like this:
    while not end_of_batch:
        data_batch = next_data_batch
        if monitor is not None:
            monitor.tic()
        self.forward_backward(data_batch)
        self.update()
        try:
            # pre fetch next batch
            next_data_batch = next(data_iter)
            self.prepare(next_data_batch)
        except StopIteration:
            end_of_batch = True

        self.update_metric(eval_metric, data_batch.label)

        if monitor is not None:
            monitor.toc_print()

        arg_params, aux_params = self.get_params()  # the only line I added

        if batch_end_callback is not None:
            batch_end_params = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                      eval_metric=eval_metric,
                                                      locals=locals())
            for callback in _as_list(batch_end_callback):
                callback(batch_end_params)
        nbatch += 1
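A norm-tracking callback of this kind could look roughly like the sketch below. It is only an illustration under stated assumptions, not the exact code used: `ParamNormLogger`, `l2_norm`, and `flatten` are hypothetical names, and the `BatchEndParam` namedtuple is a stand-in with the same fields as `mx.model.BatchEndParam` so the snippet runs without MXNet installed. It assumes the callback reads `arg_params` from the `locals` field, which is populated in the loop above because `get_params()` is called before `locals()` is captured.

```python
import math
from collections import namedtuple

# Stand-in with the same fields as mx.model.BatchEndParam (assumption:
# used here only so the sketch runs without MXNet).
BatchEndParam = namedtuple("BatchEndParam",
                           ["epoch", "nbatch", "eval_metric", "locals"])


def flatten(values):
    """Yield the scalars of an arbitrarily nested list or tuple."""
    for v in values:
        if isinstance(v, (list, tuple)):
            yield from flatten(v)
        else:
            yield v


def l2_norm(array):
    """Return the L2 norm of one parameter array.

    Assumption: MXNet NDArrays are converted via asnumpy().tolist();
    plain nested lists are accepted as-is.
    """
    if hasattr(array, "asnumpy"):
        array = array.asnumpy().tolist()
    if not isinstance(array, (list, tuple)):
        array = [array]  # a single scalar value
    return math.sqrt(sum(v * v for v in flatten(array)))


class ParamNormLogger:
    """Hypothetical batch-end callback that records per-parameter L2 norms."""

    def __init__(self, frequent=1):
        self.frequent = frequent  # record every `frequent` batches
        self.history = []         # entries: (epoch, nbatch, {name: norm})

    def __call__(self, param):
        if param.nbatch % self.frequent != 0:
            return
        # arg_params is a local variable of fit(), exposed through locals().
        arg_params = param.locals.get("arg_params")
        if arg_params is None:
            return
        norms = {name: l2_norm(arr) for name, arr in arg_params.items()}
        self.history.append((param.epoch, param.nbatch, norms))


# Usage with fake data in place of real parameters.
logger = ParamNormLogger()
fake = BatchEndParam(epoch=0, nbatch=0, eval_metric=None,
                     locals={"arg_params": {"fc1_weight": [[3.0, 4.0]]}})
logger(fake)
print(logger.history)  # [(0, 0, {'fc1_weight': 5.0})]
```

In a real run, an instance would be passed as `batch_end_callback` to fit(). Converting every parameter to a Python list on each batch is slow, so a larger `frequent` value is probably sensible.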
This works, but I noticed something odd: when I added the call to get_params(), the training speed increased. With the standard mx.mod.Module I got an average training speed of 750 samples/sec, but with that line added I get an average of 1050 samples/sec.
Any idea why this would speed things up?