cuDNN RNN implementation


Which of the MXNet RNN interfaces uses the cuDNN RNN implementation? I tried mxnet.gluon.rnn and mxnet.rnn.LSTMCell. I didn’t find cuDNN mentioned anywhere except for GRU, where it says "cuDNN implementation", but I am not sure whether that means it calls cuDNN functions internally. My current implementation is slow, and I want to make sure that I am using the correct function.


mxnet.gluon.rnn.LSTM and mxnet.gluon.rnn.GRU use cuDNN RNN when the context is set to mx.gpu(). mxnet.gluon.rnn.LSTMCell and its GRU variant are meant to be used when you have complex unrolling requirements.


Thanks @safrooze. Does mxnet.gluon.rnn.LSTMCell use cuDNN? The others don’t seem to have an unroll method (maybe they use a fixed sequence length).


LSTMCell does not use cuDNN’s RNN implementation. The beauty of the cuDNN RNN implementation (and the upcoming MXNet native implementation for the CPU context) is that the unrolling is done inside the operator, based on the sequence length of the data. So you can still have a static computational graph with a variable sequence length!