I’m just learning MXNet and got to the part where the API is mostly not thread safe and there are only blocking calls (e.g. WaitAll()).
@kellen brought up in Running Async Predictions how a “can_read” check would make it possible to queue forwards up and read the results only when they’re ready.
I tossed together what I can only describe as a hack, but it seems to work in limited testing.
The idea is to create a CPU-memory NDArray with a “token” written into its last element, then, after the forward, start an async copy from the net output into that CPU NDArray. Checking whether it’s ready to read just means checking whether the “token” has been overwritten.
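To make the idea concrete without depending on MXNet, here is a minimal standalone sketch of the same token pattern. It is an illustration under assumptions, not the MXNet code: a std::thread stands in for MXNet’s engine doing the async copy, and TokenBuffer, async_copy and copy_and_poll are hypothetical names. One deliberate difference: the last slot is a std::atomic<float> that the writer stores last with release ordering, so the polling read is well defined in plain C++, whereas the MXNet snippet in this post writes through a raw float.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Arbitrary sentinel, same value as in the MXNet snippet.
constexpr float NOT_READY_TOKEN = 1234321.1234f;

// Hypothetical buffer: payload plus a last element that doubles as the ready flag.
struct TokenBuffer {
  std::vector<float> data;   // every element except the last
  std::atomic<float> last;   // starts as the token, overwritten by the copy

  explicit TokenBuffer(std::size_t n)
      : data(n > 0 ? n - 1 : 0), last(NOT_READY_TOKEN) {}

  // Ready once the token in the last slot has been overwritten.
  bool ready() const {
    return last.load(std::memory_order_acquire) != NOT_READY_TOKEN;
  }
};

// Stand-in for the async copy: write the payload, then publish the last element.
void async_copy(const std::vector<float>& src, TokenBuffer& dst) {
  assert(src.size() == dst.data.size() + 1);
  for (std::size_t i = 0; i + 1 < src.size(); ++i) dst.data[i] = src[i];
  dst.last.store(src.back(), std::memory_order_release);
}

// Queue the "copy" on another thread, poll until the token is gone, then read.
float copy_and_poll(const std::vector<float>& src) {
  TokenBuffer out(src.size());
  std::thread worker([&src, &out] { async_copy(src, out); });
  while (!out.ready()) {
    std::this_thread::yield();  // other useful work could go here instead
  }
  // The acquire load in ready() makes the payload writes visible here.
  float sum = out.last.load(std::memory_order_acquire);
  for (float v : out.data) sum += v;
  worker.join();
  return sum;
}
```

For example, copy_and_poll({1.f, 2.f, 3.f, 4.f}) sums the copied values once the token disappears. As with the original trick, a real result whose last element happens to equal NOT_READY_TOKEN would never look ready.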
As I mentioned, I’m new to MXNet, so this may be a bad idea, but I don’t believe there are any threading issues. It also doesn’t avoid any of the blocking in ndArray.SyncCopyFromCPU.
The questionable part is a const_cast on the pointer returned by ndArray.GetData(), but the underlying C API isn’t marked const, so I figured writing through it is safe for a CPU-memory NDArray.
I wish there were a better way than polling; a callback would be preferable.
Anyway, here’s the code that marks the buffer with a token:
```cpp
#include <functional>
#include <numeric>
#include <vector>
#include "mxnet-cpp/MxNetCpp.h"
using namespace mxnet::cpp;

namespace {
/* Arbitrary number */
const float NOT_READY_TOKEN = 1234321.1234f;
}

std::vector<NDArray> outputs;
outputs.reserve(executor->outputs.size());
executor->Forward(false);
for (const auto& output : executor->outputs) {
  auto output_shape = output.GetShape();
  const int shape_size = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                         std::multiplies<>());
  /* Create the output NDArray on the CPU; delay_alloc = false so memory exists now */
  NDArray out(output_shape, Context(kCPU, 0), false, output.GetDType());
  /* Assumes a float32 output */
  const auto data = out.GetData();
  /* Reference to the last element of the NDArray */
  float& fdata = const_cast<float&>(*(data + shape_size - 1));
  /* Set our not-ready token */
  fdata = NOT_READY_TOKEN;
  /* Queue the async copy into our output NDArray */
  output.CopyTo(&out);
  outputs.push_back(std::move(out));
}
```
And here’s how to check whether it’s ready to read:
```cpp
bool check_output_ready(const std::vector<NDArray>& outputs) {
  for (const auto& out : outputs) {
    const auto data = out.GetData();
    auto output_shape = out.GetShape();
    const int shape_size = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                           std::multiplies<>());
    /* If the last element still holds the token, the copy hasn't landed yet */
    const auto fdata = data[shape_size - 1];
    if (fdata == NOT_READY_TOKEN) {
      return false;
    }
  }
  return true;
}
```
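Since polling is the only option, check_output_ready fits naturally inside a small loop with a timeout, so the caller decides how long to wait. This is a minimal sketch assuming it runs on the same thread that owns the executor (the API isn’t thread safe); poll_until and its interval parameter are hypothetical helpers, not part of MXNet.

```cpp
#include <cassert>
#include <chrono>
#include <thread>

// Hypothetical helper: evaluates `ready` until it returns true or `timeout`
// elapses, sleeping `interval` between checks. Returns true if it became ready.
template <class Pred>
bool poll_until(Pred ready, std::chrono::milliseconds timeout,
                std::chrono::microseconds interval = std::chrono::microseconds(100)) {
  assert(interval > std::chrono::microseconds::zero());  // zero would busy-spin
  const auto deadline = std::chrono::steady_clock::now() + timeout;
  while (true) {
    if (ready()) return true;
    if (std::chrono::steady_clock::now() >= deadline) return false;
    std::this_thread::sleep_for(interval);
  }
}
```

Usage would look like poll_until([&] { return check_output_ready(outputs); }, std::chrono::milliseconds(50)). Instead of sleeping, the loop body is also a natural place to queue the next forward.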