Multinomial sampling without replacement

I am using the Python API for MXNet. I would like to use the operator mxnet.symbol.sample_multinomial to draw multiple samples WITHOUT replacement. One option would be to implement this in NumPy and wrap it in a CustomOp, but then the sampling would happen on the CPU. If possible, I would like a solution that keeps everything on the GPU.
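One GPU-friendly approach is the Gumbel-top-k trick. This is a swapped-in technique, not something sample_multinomial offers itself. You add independent Gumbel noise to the log-probabilities and keep the indices of the k largest perturbed scores. The result is k distinct indices, distributed as if they were drawn one at a time, each draw proportional to the probabilities that remain. The sketch below uses NumPy so it runs anywhere. sample_without_replacement is a hypothetical helper name, and the sketch assumes each row of probs holds non-negative weights with at least k non-zero entries.

```python
import numpy as np


def sample_without_replacement(probs, k, rng=None):
    """Hypothetical helper (not an MXNet API): draw k distinct indices per row
    of `probs` using the Gumbel-top-k trick. Rows need not be normalized, but
    each row must have at least k non-zero weights."""
    rng = np.random.default_rng() if rng is None else rng
    probs = np.asarray(probs, dtype=np.float64)
    if probs.ndim == 1:
        probs = probs[None, :]
    # Uniform noise in [tiny, 1) so that -log(-log(u)) is always finite.
    u = rng.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=probs.shape)
    gumbel = -np.log(-np.log(u))
    # Perturbed log-weights; zero-weight entries become -inf and are never chosen.
    with np.errstate(divide="ignore"):
        scores = np.log(probs) + gumbel
    # Indices of the k largest perturbed scores in each row, largest first.
    return np.argsort(-scores, axis=-1)[:, :k]


rng = np.random.default_rng(0)
probs = np.array([[0.5, 0.3, 0.2, 0.0],
                  [0.1, 0.1, 0.4, 0.4]])
print(sample_without_replacement(probs, 3, rng))
```

The sketch uses only elementwise uniform noise, log, addition, and a top-k. Those steps should therefore translate to MXNet operators such as mx.nd.random.uniform, mx.nd.log and mx.nd.topk (with ret_typ='indices') running on a GPU context, with no CustomOp needed. That mapping is an assumption worth checking against the installed MXNet version.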

Thanks a lot in advance for any suggestions.