As mentioned in a previous discussion, we are users of the MXNet Scala library, and we have discovered a couple of threading-related issues in it.
We plan to fix these within the Scala library and contribute the fix back to the project. There are a few ways we could do this, some more invasive than others. We’re looking for opinions on the best path forward, and were told this site was the best place to solicit such feedback.
Issues
The core issue is that the non-naive MXNet engines are unsafe to call from more than one thread over the lifetime of the process, not merely unsafe to call concurrently. This means that once MXNet has been used from a given thread, it must never be called from any other thread. Calling from any other thread causes memory corruption, which manifests as various types of failure when running at scale.
There are two primary issues here:
Issue 1: Usability
Since many JVM environments are multi-threaded, it’s difficult to use MXNet safely. The usual approaches for taming a non-threadsafe library, such as thread-local instances or global mutexes, are not effective, as MXNet requires all calls to the engine, for the lifetime of the process, to be made from the same thread.
This means that any user wanting to use the Scala library in a “normal” JVM context, such as a server environment or a processing system like Spark, needs to do extra work to ensure that only a single thread interacts with the MXNet library. When this requirement is violated, MXNet fails in surprising (and severe) ways that can take significant work to diagnose.
Issue 2: Finalization of native-backed resources
Like other libraries which manage native-backed resources, MXNet uses the Java finalize() method (finalizers), which run when an object is garbage-collected, to clean up native resources that the library user has not released explicitly. An example is an NDArray on which dispose() is not explicitly called; in that case the finalizer disposes the native array on the user’s behalf.
Although this is convenient, it also causes the memory corruption described above, because finalizers always run on the JVM’s special “Finalizer” thread. This is never the thread on which other MXNet work is performed, so every finalization has the potential to corrupt memory. Our experiments indicate that, at reasonable scale, this will happen with certainty.
As such, the way finalizers are used in MXNet is unsafe and, at minimum, needs to change.
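To make the failure mode concrete, the following sketch (ours, not MXNet code) demonstrates that finalize() runs on the JVM’s dedicated finalizer thread (typically named “Finalizer” on HotSpot), never on the thread that created the object:

```scala
import java.util.concurrent.{CountDownLatch, TimeUnit}

// Sketch (not MXNet code): records which thread runs finalize().
object FinalizerThreadDemo {
  @volatile var finalizerThreadName: Option[String] = None
  private val latch = new CountDownLatch(1)

  private class Probe {
    override protected def finalize(): Unit = {
      finalizerThreadName = Some(Thread.currentThread().getName)
      latch.countDown()
    }
  }

  def main(args: Array[String]): Unit = {
    new Probe() // immediately unreachable
    System.gc()
    System.runFinalization()
    latch.await(2, TimeUnit.SECONDS) // finalization is best-effort
    println(s"created on:   ${Thread.currentThread().getName}")
    println(s"finalized on: $finalizerThreadName")
  }
}
```

If the object in this sketch were an MXNet resource whose finalizer called into the engine, the call would come from the finalizer thread and trigger exactly the corruption described above.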
Proposed fixes
We have two main proposals. The first is non-invasive but less convenient; the second will make the library easier to use but removes some fine-grained control from library users.
We also note a more radical change, for discussion’s sake.
Fix 1: Change finalizers to log leaks but not clear them up
This is effectively a minimal fix which will make the library usable. We propose that we:
- Remove the “dispose” functionality from the finalizers in the library
- Replace it with functionality which will log a warning that a leak has occurred
- Add a system property, mxnet.traceLeakedResources, which traces the creation of all “leakable” resources and adds that trace to the warning message (this is behind a switch because the runtime cost of collecting the information may be high)
This allows users to understand where leaks originate and fix them, but doesn’t change the execution model of the library at all.
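A sketch of what such a leak-logging finalizer might look like (the class and method names here are illustrative, not the library’s):

```scala
// Illustrative sketch of Fix 1: the finalizer logs the leak instead of
// disposing the native handle (which would be unsafe off the MXNet thread).
abstract class LeakLoggingResource {
  // Property name from the proposal; the trace is captured eagerly at
  // construction time, which is why it is gated behind a switch.
  private val creationTrace: Option[Array[StackTraceElement]] =
    if (java.lang.Boolean.getBoolean("mxnet.traceLeakedResources"))
      Some(Thread.currentThread().getStackTrace)
    else None

  @volatile private var disposed = false
  def isDisposed: Boolean = disposed

  /** Explicit, safe release; must be called on the MXNet thread. */
  def dispose(): Unit = {
    // ... free the native handle here ...
    disposed = true
  }

  override protected def finalize(): Unit =
    if (!disposed) {
      val origin = creationTrace
        .map(_.mkString("\n  created at: ", "\n              ", ""))
        .getOrElse(" (set -Dmxnet.traceLeakedResources=true to trace creation sites)")
      System.err.println(s"MXNet resource leaked: $this$origin")
    }
}
```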
Since several of the leaks we’ve identified occur in the Scala library itself, we would use this functionality in order to trace and fix leaks in code we currently use (we have already done this well enough that we can run predictions at scale).
Fix 2: Implement a “dispatcher” thread in the Scala library
In our own use of MXNet, we’ve found it’s easiest to ensure that all use of MXNet occurs on one thread by mediating it through a singleton “dispatcher” thread in our codebase. This involves starting a daemon thread and exposing an object with an interface like the following:
object MXNetThread {
  /**
   * Runs a task on the MXNet thread.
   *
   * If the MXNet thread is the current thread, the task is run immediately.
   * Otherwise the task is passed to the MXNet thread and the current thread
   * blocks until it returns.
   */
  def run[T](task: => T): T = ???
}
This allows one to dispatch MXNet work like this:
val result = MXNetThread.run {
  // Do some work using MXNet
}

def someMethod(someArgs: String): String = MXNetThread.run {
  // Do some work with MXNet
}
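For reference, a minimal sketch of how such a dispatcher could be implemented, using a single-threaded executor backed by a daemon thread (this is our own illustration, not the library’s code; a real implementation would also need to unwrap ExecutionException and handle interruption):

```scala
import java.util.concurrent.{Callable, Executors, ThreadFactory}

// Sketch: confine all MXNet work to one daemon thread, running tasks
// inline when already on that thread to avoid self-deadlock.
object MXNetThread {
  private val executor = Executors.newSingleThreadExecutor(new ThreadFactory {
    def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "MXNet")
      t.setDaemon(true) // don't keep the JVM alive
      t
    }
  })

  // Capture the dispatcher thread once, so run() can detect re-entrant calls.
  private val mxnetThread: Thread =
    executor.submit(new Callable[Thread] {
      def call(): Thread = Thread.currentThread()
    }).get()

  def run[T](task: => T): T =
    if (Thread.currentThread() eq mxnetThread) task
    else executor.submit(new Callable[T] { def call(): T = task }).get()
}
```

The re-entrancy check is what makes nested calls like MXNetThread.run { MXNetThread.run { ... } } safe: the inner call runs inline instead of submitting to (and blocking on) its own thread.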
This proposal would be to push this work into the Scala library itself, and force all uses of the MXNet native library to go through this thread. This would make the library safe to use in all configurations without additional work from the users of the library. This would also allow us to rewrite the finalizers to dispose objects safely, as they can defer this disposal to the MXNet thread.
The downsides are:
- Forces all users of the library to use the thread. This shouldn’t be necessary when using the NaiveEngine, and others might want to implement their own ways of handling MXNet’s threading issues that fit better into their containers and so on.
- Causes possibly-surprising blocking while waiting for the thread to become available. A simple implementation won’t make this easy to measure.
- Very invasive to the library in general, as this needs to be implemented everywhere that the MXNet native library is used.
Sidenote: The “Nuclear” option
There’s a final more dramatic option, where the Scala library is pared down to only support prediction (developments like Gluon indicate that Scala isn’t likely to support the state-of-the-art in model construction and training in the future anyway). This would allow Fix 2 to be implemented more easily, and provide a “leaner” Scala library that could be used just to load in models created externally.
Request for comment
We’re definitely going to implement at least the first fix, but we’d like to know whether there’s any appetite for the second (or possibly even the third). We’d also be very interested to hear if anyone has other ideas about how to solve this for users of the Scala library.
Thanks all!