I am trying to implement a LogSparse Transformer.
Like this one (https://openai.com/blog/sparse-transformer/) or this one (https://papers.nips.cc/paper/8766-enhancing-the-locality-and-breaking-the-memory-bottleneck-of-transformer-on-time-series-forecasting.pdf).
I managed to reduce memory consumption through the Module API and MXNet Memonger, but I still have some issues with sparse attention.
I feel like something is missing from the nd.array sparse API: the CSR format can only handle 2D data, while the transformer produces tensors with more than two dimensions (batch, head, query, key).
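One trick I've been experimenting with (sketched here with SciPy, whose CSR format has the same 2D-only restriction) is folding the leading batch and head axes into the row axis, so a 4D attention tensor becomes one tall 2D CSR matrix. This only demonstrates storage and round-tripping, not batched sparse matmul:

```python
import numpy as np
import scipy.sparse as sp

# Illustrative shapes, not from any real model.
batch, heads, L = 2, 4, 16
rng = np.random.default_rng(0)

# Fake sparse attention weights: zero out ~90% of entries.
dense = rng.random((batch, heads, L, L))
dense[dense < 0.9] = 0.0

# Fold (batch, heads, L, L) -> (batch*heads*L, L): each (L, L)
# attention slice is stacked row-wise into one 2D matrix.
flat = dense.reshape(batch * heads * L, L)
csr = sp.csr_matrix(flat)

# Round-trip back to 4D when a dense view is needed.
restored = csr.toarray().reshape(batch, heads, L, L)
assert np.allclose(restored, dense)
print(csr.shape, csr.nnz)
```

The obvious downside is that a per-head sparse-times-dense product no longer falls out of a single 2D matmul, so I'm not sure this actually unblocks the attention computation itself.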
Am I wrong? Does MXNet have all the tools needed to implement sparse attention? Maybe there are some tricks to overcome these limitations.
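For reference, the log-sparse attention pattern I'm trying to express looks like this (a NumPy sketch of my understanding of the paper: each position attends to itself plus positions at exponentially growing distances into the past, giving O(L log L) nonzeros instead of O(L^2)):

```python
import numpy as np

def logsparse_mask(seq_len):
    """Boolean mask where mask[i, j] is True if query position i
    may attend to key position j (causal, log-sparse pattern)."""
    mask = np.zeros((seq_len, seq_len), dtype=bool)
    for i in range(seq_len):
        mask[i, i] = True          # always attend to self
        step = 1
        while i - step >= 0:
            mask[i, i - step] = True  # positions 1, 2, 4, 8, ... steps back
            step *= 2
    return mask

mask = logsparse_mask(8)
# Row i has at most floor(log2(i)) + 2 nonzeros, so the whole
# mask has O(L log L) entries set.
print(mask.astype(int))
```

This is exactly the kind of per-row sparsity I'd like to store in a CSR-style structure, one mask per head and per batch element, which is where the 2D limit bites.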