Gradients through sparse.dot


#1

I’m running a model that performs a few nd.sparse.dot(<CSRNDArray @gpu(0)>, <NDArray @gpu(0)>) operations on the GPU in the forward pass. When I run loss.backward() with the Adam optimizer, I get the following message:

Storage type fallback detected:
operator = _backward_FullyConnected
input storage types = [row_sparse, default, default, ]
output storage types = [default, default, default, ]
params = {"no_bias" : False, "flatten" : True, "num_hidden" : 64, }
context.dev_mask = gpu
The operator with default storage type will be dispatched for execution. You’re seeing this warning message because the operator above is unable to process the given ndarrays with specified storage types, context and parameter. Temporary dense ndarrays are generated in order to execute the operator. You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning.

Is this telling me my CSRNDArray is getting converted to a dense array to perform the backward pass? What’s the best way to confirm this? And if so, why does it say “row_sparse” in the warning message?

Thanks!


#2

Hi @cvitkom, this depends on your model architecture.
For some reason the backward pass of the FullyConnected layer (_backward_FullyConnected) is receiving a row_sparse output gradient when it computes its input gradients. Which layer comes after the FC layer? The backward passes of sparse.dot and SparseEmbedding usually generate “row_sparse” gradients.
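
To see why the backward of a CSR-times-dense product tends to yield a row_sparse gradient, here is a minimal pure-Python sketch. It does not use MXNet, and `csr_transpose_dot` is a hypothetical helper written only for this illustration. For Y = dot(A, W) with a sparse A, the gradient with respect to W is dot(A.T, dY), so only the rows of W that match columns of A holding at least one stored entry can receive a nonzero gradient.

```python
def csr_transpose_dot(data, indices, indptr, ograd):
    """Compute dot(A.T, ograd) for a CSR matrix A, visiting only stored entries.

    Returns a dict {row_index: gradient_row}, i.e. a row-sparse layout:
    rows that are absent from the dict are entirely zero.
    """
    width = len(ograd[0])
    grad_rows = {}
    for i in range(len(indptr) - 1):               # rows of A
        for k in range(indptr[i], indptr[i + 1]):  # stored entries in row i
            j, a_ij = indices[k], data[k]
            row = grad_rows.setdefault(j, [0.0] * width)
            for c in range(width):
                row[c] += a_ij * ograd[i][c]
    return grad_rows

# A is 3 x 5 in CSR form; only columns 1 and 4 hold nonzeros:
# [[0, 2, 0, 0, 0],
#  [0, 0, 0, 0, 3],
#  [0, 1, 0, 0, 0]]
data = [2.0, 3.0, 1.0]
indices = [1, 4, 1]
indptr = [0, 1, 2, 3]

# Gradient flowing back into Y = dot(A, W); shape 3 x 2.
ograd = [[1.0, 1.0], [1.0, 2.0], [0.5, 0.0]]

grad_w = csr_transpose_dot(data, indices, indptr, ograd)
print(sorted(grad_w))  # indices of the rows of W that receive a gradient
print(grad_w)
```

Only rows 1 and 4 of the 5-row gradient are populated, which is the kind of data a row_sparse array stores. An operator that receives such a gradient but has no row_sparse implementation would need it converted to a dense array first, which is the situation the fallback warning describes.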

One way to confirm this is to check the profiling result and see whether _backward_FullyConnected takes a long time (I think the profiler is now enabled by default in builds installed via pip install --pre). The fallback converts the sparse NDArray to a dense one, but it does not affect correctness. If it is not a bottleneck when training your network, it is usually fine.
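
As a rough way to read a profiling result, the sketch below sums the time spent per operator name in a Chrome-trace-style JSON. It assumes the dump has a top-level `traceEvents` list whose entries carry `name`, `ph` ("B" for begin, "E" for end) and `ts` in microseconds; please check this against the file your MXNet version actually writes. The sample events are made up for illustration, and `time_per_operator` is a hypothetical helper, not an MXNet API.

```python
import json
from collections import defaultdict

def time_per_operator(trace):
    """Sum (end - begin) per event name, pairing B/E events in order."""
    open_events = defaultdict(list)  # name -> stack of begin timestamps
    totals = defaultdict(float)      # name -> total microseconds
    for event in trace["traceEvents"]:
        name, phase = event.get("name"), event.get("ph")
        if phase == "B":
            open_events[name].append(event["ts"])
        elif phase == "E" and open_events[name]:
            totals[name] += event["ts"] - open_events[name].pop()
    return dict(totals)

# Made-up sample; in practice load the dumped file with json.load instead.
sample = json.loads("""
{"traceEvents": [
  {"name": "_backward_FullyConnected", "ph": "B", "ts": 100},
  {"name": "_backward_FullyConnected", "ph": "E", "ts": 900},
  {"name": "dot", "ph": "B", "ts": 1000},
  {"name": "dot", "ph": "E", "ts": 1150},
  {"name": "_backward_FullyConnected", "ph": "B", "ts": 1200},
  {"name": "_backward_FullyConnected", "ph": "E", "ts": 1900}
]}
""")

totals = time_per_operator(sample)
# Print operators from most to least total time.
for name, micros in sorted(totals.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {micros:.0f} us")
```

If _backward_FullyConnected dominates a listing like this in your real trace, the dense conversion is a plausible suspect. The sketch pairs events by name only and ignores threads and devices, so treat the totals as an approximation.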

BTW, Gluon doesn’t support sparse storage yet. Support should be available in one or two months.


#3

Thanks for the reply!

I use a number of fully connected layers throughout the architecture, so it’s a bit hard to tell what’s going on. Is there any way I can pinpoint which step(s) in the forward or backward pass are raising the warning? Can I give the _operators specific names or something?


#4

Just for completeness: I think updating to MXNet 1.2 fixed this. Thanks!