Training speed with and without the model's evaluation metric


#1

During model training (an LSTM language model) I was computing perplexity (ppl) on a training batch every few batches and noticed that the ppl computation seemed to take a lot of time. I disabled the per-batch ppl computation and now only compute ppl on the dev data after each epoch. What I see now is that the dev ppl computation, which used to take less than 10 mins, takes 40 mins.

  • With ppl computation every few batches plus ppl computation on dev, one epoch took 1 hour, of which the dev computation took 10 mins.
  • With ppl computed only on dev, one epoch takes more or less the same time (1 hour), but the dev computation now takes 40 mins.

I was expecting this change to reduce the training time. With the per-batch ppl computation disabled, the training part of the epoch completes much faster, but the time is instead spent on the dev computation. I don't understand why the total is almost exactly the same. Does mxnet do some kind of lazy evaluation? Please help me understand.

Is this related to https://github.com/apache/incubator-mxnet/issues/9571? I call mxnet.metric.Perplexity.get() to compute and retrieve the ppl.
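
For reference, the loop is structured roughly like this (a minimal sketch; net, loss_fn, trainer, train_iter, dev_iter, num_epochs and eval_interval are placeholders for my actual model, loss, optimizer, data and settings):

import mxnet as mx
from mxnet import autograd

metric = mx.metric.Perplexity(ignore_label=None)

for epoch in range(num_epochs):
    for i, (data, label) in enumerate(train_iter):
        with autograd.record():
            output = net(data)
            loss = loss_fn(output, label)
        loss.backward()
        trainer.step(data.shape[0])
        # previously: ppl on the training batch every few batches
        # if (i + 1) % eval_interval == 0:
        #     metric.update([label], [output])
        #     print('train', metric.get())

    # now: ppl only on the dev data, once per epoch
    metric.reset()
    for data, label in dev_iter:
        metric.update([label], [net(data)])
    print('dev', metric.get())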


#2

Hi, how did you determine that the ppl computation is taking a lot of time? MXNet calls are asynchronous. You can run the MXNet profiler (https://github.com/apache/incubator-mxnet/blob/master/docs/faq/perf.md#profiler) to figure out the time spent on each operator.
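
For instance, a tiny snippet like this (just a sketch) shows the asynchrony: the operator call returns almost immediately, and the time only shows up once you ask for the result:

import time
import mxnet as mx

a = mx.nd.ones((2000, 2000))

tic = time.time()
b = mx.nd.dot(a, a)        # returns right away, the work is only queued
print('enqueue : %.4f s' % (time.time() - tic))

tic = time.time()
b.asnumpy()                # blocks until the matrix multiply has actually run
print('asnumpy : %.4f s' % (time.time() - tic))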


#3

Thanks for the reply, eric.

Earlier I assumed that the ppl computation itself was taking the time, but it turns out to be the model update. Here is my observation.

batch1 -> run forward backward -> step
batch2 -> run forward backward -> step
.
.
.
batch100 -> run forward backward -> step -> evaluation metric (mx.metric.Perplexity) is run on this batch only

For each batch we do (roughly, with net and loss_fn standing in for my model and loss):

with autograd.record():
    output = net(data)
    loss = loss_fn(output, label)
loss.backward()
trainer.step(batch_size)
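
And only on every 100th batch, right after trainer.step(), I additionally run something like this (sketch; metric is an mx.metric.Perplexity instance, i is the batch index and 100 is my eval interval):

if (i + 1) % 100 == 0:
    # this is the first place where the queued forward/backward/step
    # work seems to actually be forced to finish
    metric.update([label], [output])
    print(metric.get())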

If I run the evaluation (in my case mx.metric.Perplexity) on the batch100 data only, after step (NOTE: at this point the data has already been read, forward and backward have already been computed, and step has been run, which calls the gradient reduce and the update), I see that
a significant amount of time is spent before the first mx.metric.Perplexity.update() call completes. My guess is that mxnet does lazy evaluation (not sure, just a guess) and updates the parameters (or performs some other pending work) only when the result is actually needed, i.e. when it is called/accessed.

If, instead of batch100 only, I first run the evaluation on the batch95 data and then run it on batch100, the time spent on the batch95 evaluation is very high while the batch100 evaluation takes much less. Batch100 takes less time in this case because only 5 more batches' worth of work is pending, so the model has to catch up on (if the lazy-evaluation hypothesis is true) far less work than the batch95 evaluation, which has to catch up on 95 batches' worth.

This is just my guess from the experimental observations. Let me know if it is incorrect and some other behaviour of mxnet is responsible for this.

I have yet to get the profiling results. The profiler dumps too much data (close to 50 GB for just a few batches). I am trying to see how to profile only the evaluation; that might help me understand better.
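
What I am planning to try is to turn the profiler on only around the evaluation, roughly like this (a sketch using mx.profiler.set_config / set_state; metric, label, output are the ones from my loop, and I haven't confirmed yet that this keeps the dump small):

import mxnet as mx

mx.profiler.set_config(profile_all=True, aggregate_stats=True,
                       filename='eval_profile.json')

mx.nd.waitall()                    # drain the pending training work first,
                                   # so it is not attributed to the evaluation
mx.profiler.set_state('run')
metric.update([label], [output])   # the evaluation I want to profile
mx.nd.waitall()
mx.profiler.set_state('stop')
print(mx.profiler.dumps())         # aggregated per-operator summary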


#4

In your training loop, do you ever call asnumpy()? That forces synchronization. If you only have the following code:

for batch in data_iter:
    with autograd.record():
        loss = net(batch)   
    loss.backward()
    trainer.step(batch.shape[0])
print('Reached here but not completed')
print(loss.asnumpy())  # this being printed means the computation for the loss is done

Reaching the first print statement doesn't mean all the forward/backward computations have completed.

The perplexity metric, however, calls .asnumpy() on the forward outputs, which forces synchronization. If you call mx.nd.waitall() or loss.asnumpy() before you update the metric, do you observe anything different?
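
For example, something along these lines (just a sketch; metric, label, output are the ones from your training loop) should separate the two costs:

import time
import mxnet as mx

tic = time.time()
mx.nd.waitall()                   # or loss.asnumpy(); both block on the queued work
print('sync before metric : %.1f s' % (time.time() - tic))

tic = time.time()
metric.update([label], [output])  # should now be cheap if the time was in the queued work
print('metric update      : %.1f s' % (time.time() - tic))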


#5

Calling mx.nd.waitall() or loss.asnumpy() before the metric update didn't change the speed; it remains the same. To summarize:

  • Parameters are not updated immediately; the queued work is only forced to finish when we actually need the result. To force synchronization we can use mx.nd.waitall() (as per the docs this is meant for benchmarking) or access the loss (e.g. loss.asnumpy()). I believe Module already does that.

I noticed an increase in speed when I disabled the per-few-batches metric computation and thought that the metric computation was slowing down the training. In reality, by disabling the metric computation I was only delaying the synchronization: when I finally called the metric after many batches, it forced synchronization of all the queued work, and the overall speed was more or less the same.
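
A rough way to convince myself of this (a standalone toy example, not my model) is that moving the synchronization point doesn't change the total amount of work, only where the waiting shows up:

import time
import mxnet as mx

a = mx.nd.ones((1000, 1000))
tic = time.time()
for _ in range(20):
    a = mx.nd.dot(a, a) / 1000.0
    a.wait_to_read()             # synchronize after every step
print('sync every step : %.2f s' % (time.time() - tic))

b = mx.nd.ones((1000, 1000))
tic = time.time()
for _ in range(20):
    b = mx.nd.dot(b, b) / 1000.0
mx.nd.waitall()                  # synchronize once at the end
print('sync at the end : %.2f s' % (time.time() - tic))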

I now have to look into the implementation to increase the speed.

Thanks, eric, for the help. I now have a better understanding of the issue.


#6

I have encountered the same problem. Could you share your implementation?