Training is faster when get_params() is called every mini-batch

I’m training a ResNet-50 model for image classification. I wanted to track the norms of my parameters, so I wrote a custom callback to compute them and a custom mx.mod.Module with fit() overridden (the callback itself is sketched after the loop below). The only thing I changed in fit() was to add a call to get_params() to pull the parameters back from the GPUs after every update. My training loop in that function looks like this:

        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()
            try:
                # prefetch the next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            arg_params, aux_params = self.get_params() # the only line I added

            if batch_end_callback is not None:
                batch_end_params = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1
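
For reference, the callback itself is roughly this (a simplified sketch, not the exact code; the name and the printing are illustrative). It reads arg_params out of param.locals, which works because fit() passes locals=locals() into BatchEndParam:

    import mxnet as mx

    def param_norm_callback(param):
        # batch-end callback: read the parameters that get_params() pulled
        # back from the devices (visible via param.locals because fit()
        # passes locals=locals() into BatchEndParam)
        arg_params = param.locals.get('arg_params')
        if arg_params is None:
            return
        for name, arr in sorted(arg_params.items()):
            # asscalar() copies the value to the CPU, which also blocks
            # until the norm has actually been computed
            print('%s L2 norm: %f' % (name, mx.nd.norm(arr).asscalar()))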

This works, but I noticed something odd: when I added the call to get_params(), the training speed increased. Using the standard mx.mod.Module I got an average training speed of 750 samples/sec, but with that one extra line I get an average of 1050 samples/sec.

Any idea why this would speed things up?


@piiswrong that’s weird. Could it be because get_params() adds some synchronization?
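
If it is synchronization, one quick way to check (just a sketch, untested): swap the get_params() call in the loop for an explicit barrier and see whether the same speedup shows up. mx.nd.waitall() only blocks until every queued asynchronous operation has finished, without copying any parameters:

    import mxnet as mx

    # in the loop, instead of:
    #     arg_params, aux_params = self.get_params()
    # add only a synchronization point and compare samples/sec:
    mx.nd.waitall()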