While training an LSTM language model, I computed perplexity (ppl) on a training batch every few batches, and noticed that the ppl computation was taking a lot of time. So I disabled ppl computation on training batches and now only compute ppl on the dev data after each epoch. However, the dev ppl computation, which used to take less than 10 mins, now takes 40 mins.
- With ppl computation every few batches plus ppl computation on dev, one epoch took 1 hour, of which the dev data computation took 10 mins.
- With ppl computation only on dev, one epoch takes roughly the same time (1 hour), but the dev data computation now takes 40 mins.
I was expecting this to reduce the training time. After disabling the per-batch ppl computation, the training part of the epoch does finish much faster (roughly 20 mins instead of 50), but the saved time is now spent in the dev data computation. I cannot understand why the total stays exactly the same. Does MXNet do some kind of lazy evaluation? Please help me understand.
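A possible explanation, offered as a hypothesis rather than a confirmed diagnosis: MXNet's execution engine runs operators asynchronously, so Python calls return before the work is done, and the cost is paid at the first call that needs a concrete value (for example `NDArray.asnumpy()`, `NDArray.asscalar()`, or `mx.nd.waitall()`). The toy model below does not use MXNet at all. `ToyAsyncEngine`, `fake_batch` and `run_epoch` are hypothetical names, and a single worker thread with `time.sleep` stands in for the GPU. It only illustrates how such an engine would move the same total time from the training loop to the dev evaluation once the mid-training synchronisation points are removed.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class ToyAsyncEngine:
    """Hypothetical stand-in for an asynchronous execution engine.

    push() enqueues work on a single worker thread and returns a Future
    immediately; only Future.result() blocks the caller.
    """

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)

    def push(self, fn, *args):
        return self._pool.submit(fn, *args)

    def shutdown(self):
        self._pool.shutdown(wait=True)


def fake_batch(cost_s):
    # Stands in for one forward/backward pass (or one dev batch).
    time.sleep(cost_s)
    return cost_s


def run_epoch(sync_every=None, n_train=20, n_dev=5, cost_s=0.02):
    """Return (seconds charged to training, seconds charged to dev eval)."""
    engine = ToyAsyncEngine()

    t0 = time.perf_counter()
    for i in range(n_train):
        fut = engine.push(fake_batch, cost_s)
        if sync_every and (i + 1) % sync_every == 0:
            # Like reading a training-batch ppl: blocks until this batch
            # (and everything queued before it) has actually run.
            fut.result()
    t_train = time.perf_counter() - t0

    t1 = time.perf_counter()
    dev_futs = [engine.push(fake_batch, cost_s) for _ in range(n_dev)]
    for f in dev_futs:
        # Reading the dev ppl must first drain any training work still queued.
        f.result()
    t_dev = time.perf_counter() - t1

    engine.shutdown()
    return t_train, t_dev


if __name__ == "__main__":
    for label, every in [("ppl every 5 batches", 5), ("dev ppl only", None)]:
        tr, dev = run_epoch(sync_every=every)
        print(f"{label}: train {tr:.2f}s, dev {dev:.2f}s, total {tr + dev:.2f}s")
```

If this is what is happening, calling `mx.nd.waitall()` before stopping each timer would likely attribute the time to the phase that actually did the work.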
Is this related to https://github.com/apache/incubator-mxnet/issues/9571? I call `mxnet.metric.Perplexity.get()` to compute and get the ppl.
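For reference, the reported value is the exponential of the mean negative log-probability assigned to the correct tokens. The minimal pure-Python sketch below is not MXNet's code: the class name `RunningPerplexity` and the `1e-10` probability floor are my assumptions. It shows that the final `get()` is only cheap arithmetic on accumulated sums, so in an asynchronous framework the cost would land wherever the probabilities are first pulled back from the device, not in the formula itself.

```python
import math


class RunningPerplexity:
    """Hypothetical accumulator: sums per-token NLL, reports exp(mean NLL)."""

    def __init__(self, eps=1e-10):
        self.eps = eps          # floor to avoid log(0)
        self.sum_nll = 0.0      # running sum of -log p(correct token)
        self.num_tokens = 0

    def update(self, token_probs):
        # token_probs: probabilities the model gave to the correct tokens.
        for p in token_probs:
            self.sum_nll -= math.log(max(self.eps, p))
        self.num_tokens += len(token_probs)

    def get(self):
        # Cheap: one division and one exp on already-accumulated values.
        if self.num_tokens == 0:
            return float("nan")
        return math.exp(self.sum_nll / self.num_tokens)
```

For example, updating with `[0.5, 0.5]` and then `[0.25, 0.25]` gives exp(1.5 ln 2) = 2^1.5, about 2.83.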