Best practices for prediction on a machine with multiple GPUs


MXNet offers great out-of-the-box features for accelerating training on machines with multiple GPUs. Because training itself is now so fast, I’m finding that a large percentage of my total job time goes to computing evaluation metrics. In case it’s pertinent, I’m using the symbolic API with a custom training loop. I compute metrics inside the training loop on a small amount of data for early stopping, and then at the end of the training loop on the test data. In total this is >50% of the total runtime of the job.

So I’d like to know—are there best practices for speeding up prediction when multiple devices are present? I believe prediction occurs on just a single GPU even if many are present, which feels like poor resource utilization if most of the GPUs are idle for long periods. Any examples would be super useful.

Some specifics if it’s pertinent. I’m solving a prediction problem with ~100K outputs, with a final softmax layer. My two target metrics are perplexity and Recall-at-k (I think MXNet calls this TopKAccuracy). Since the recall metric is very expensive to compute (requires the rank of the true label) I’m early stopping on perplexity only. I only compute the recall metric on the test data.
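For concreteness, here is a minimal sketch of the two metrics as I understand them, written in plain Python on toy numbers rather than with MXNet. The function names and the sample data are made up for illustration. The rank of the true label is obtained by counting the scores that beat it, so no full sort is needed; ties are resolved in favour of the true label, which is an assumption of this sketch.

```python
import math


def true_label_rank(scores, label):
    """Rank of the true label (0 = highest score), counted without sorting."""
    target = scores[label]
    # Strict comparison: scores tied with the true label do not push it down.
    return sum(1 for s in scores if s > target)


def perplexity(prob_rows, labels):
    """exp of the mean negative log-probability assigned to the true labels."""
    nll = -sum(math.log(row[y]) for row, y in zip(prob_rows, labels))
    return math.exp(nll / len(labels))


def recall_at_k(prob_rows, labels, k):
    """Fraction of examples whose true label is among the k highest scores."""
    hits = sum(
        1 for row, y in zip(prob_rows, labels) if true_label_rank(row, y) < k
    )
    return hits / len(labels)


# Toy softmax outputs: 3 examples, 4 classes (hypothetical numbers).
probs = [
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.5, 0.2, 0.1],
    [0.1, 0.2, 0.3, 0.4],
]
labels = [0, 2, 1]

print("perplexity:", perplexity(probs, labels))
print("recall@1:", recall_at_k(probs, labels, k=1))
print("recall@3:", recall_at_k(probs, labels, k=3))
```

Both metrics reduce to a couple of running scalars per batch (a sum of log-probabilities and a hit count), so only those scalars would need to leave the device if the same reductions were done there.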


Could you post the code snippet? If you have profiling results to share, that would be very helpful too. In the meantime, setting the correct environment variables may help speed up computation.


The metrics in MXNet convert NDArrays to NumPy arrays, which is slow for large prediction results because of the high communication cost. @szha is working on removing those conversions.


So I found two problems by reading through my code and the MXNet source, and thinking a bit.

  1. As @eric-haibin-lin notes, some of the metrics call asnumpy() on the ndarrays very early. If you have many outputs this is extremely slow.
  2. I had written my own metric and was updating it with a sequence of mod.forward(data), metric.update(mod.get_outputs()). Through reading the source I realised that this is quite different from calling mod.update_metric(metric, labels).
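My reading of the difference in the second point (an assumption based on the MXNet 1.x source, so worth checking against your version) is that mod.get_outputs() merges the outputs from all devices into one array by default, whereas mod.update_metric passes each device's outputs to the metric together with the matching slice of the labels. The sketch below is plain Python, not MXNet: RunningPerplexity and the toy shards are hypothetical, and it only illustrates that a metric built from running sums gives the same answer when updated shard by shard, so the merged copy is not needed.

```python
import math


class RunningPerplexity:
    """Hypothetical metric that keeps only two running scalars."""

    def __init__(self):
        self.neg_log_sum = 0.0
        self.count = 0

    def update(self, labels, prob_rows):
        # Accumulate -log p(true label) for every example in this chunk.
        for y, row in zip(labels, prob_rows):
            self.neg_log_sum -= math.log(row[y])
            self.count += 1

    def get(self):
        return math.exp(self.neg_log_sum / self.count)


# Toy outputs of one batch split across two devices (hypothetical numbers).
shard_probs = [
    [[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]],    # device 0
    [[0.3, 0.3, 0.4], [0.5, 0.25, 0.25]],  # device 1
]
shard_labels = [[0, 1], [2, 0]]

# Pattern A: merge every device's outputs first, then update once.
merged = RunningPerplexity()
merged.update(
    [y for labels in shard_labels for y in labels],
    [row for probs in shard_probs for row in probs],
)

# Pattern B: update shard by shard, never building the merged copy.
per_shard = RunningPerplexity()
for labels, probs in zip(shard_labels, shard_probs):
    per_shard.update(labels, probs)

print("merged:   ", merged.get())
print("per shard:", per_shard.get())
```

With ~100K outputs per example, the merged copy in pattern A is the expensive part, which is why the per-device pattern matters here.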