I need to build a mask symbol out of a list of indexes.
For example, given a set of indexes idxs:
I want a mask like the following,
where idxs has shape [batch_size, num_indexes]
and mask has shape [batch_size, max_values].
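To pin down the intended semantics, here is a small NumPy reference. It assumes the mask should hold a 1 at every column listed in the matching row of idxs and a 0 everywhere else. The helper name `reference_mask` and the input values are made up for illustration.

```python
import numpy as np

def reference_mask(idxs: np.ndarray, max_values: int) -> np.ndarray:
    """Dense 0/1 mask of shape [batch_size, max_values] (hypothetical helper)."""
    batch_size = idxs.shape[0]
    mask = np.zeros((batch_size, max_values), dtype=np.float32)
    # Row ids of shape [batch_size, 1] broadcast against idxs [batch_size, num_indexes],
    # so row b gets a 1 at every column listed in idxs[b].
    rows = np.arange(batch_size)[:, None]
    mask[rows, idxs] = 1.0
    return mask

idxs = np.array([[1, 3], [0, 2]])  # batch_size = 2, num_indexes = 2
print(reference_mask(idxs, max_values=5).tolist())
# → [[0.0, 1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0, 0.0]]
```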
At the moment I am using the following:
mask = mx.sym.sum(
Unfortunately, this takes a lot of GPU memory: usage scales as batch_size * max_values * num_indexes. I was wondering if anyone has ideas on how to do this more efficiently.
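A batch_size * max_values * num_indexes footprint is what you would expect from a one-hot intermediate of shape [batch_size, num_indexes, max_values] that is then summed over axis 1. That is an assumption, since the full expression is not shown. A scatter avoids that intermediate by writing ones straight into a [batch_size, max_values] buffer, so memory is roughly O(batch_size * (num_indexes + max_values)). The NumPy sketch below compares the two; both function names are hypothetical.

```python
import numpy as np

def mask_via_one_hot(idxs, max_values):
    """Presumed current approach (hypothetical reconstruction)."""
    # One-hot intermediate: shape [batch_size, num_indexes, max_values].
    one_hot = np.eye(max_values, dtype=np.float32)[idxs]
    # Summing over the index axis gives [batch_size, max_values];
    # an index repeated within a row is counted more than once.
    return one_hot.sum(axis=1), one_hot.size

def mask_via_scatter(idxs, max_values):
    """Scatter approach: no [batch_size, num_indexes, max_values] intermediate."""
    batch_size, num_indexes = idxs.shape
    # Row id for every entry of idxs, flattened in the same order as idxs.reshape(-1).
    rows = np.repeat(np.arange(batch_size), num_indexes)
    cols = idxs.reshape(-1)
    mask = np.zeros((batch_size, max_values), dtype=np.float32)
    mask[rows, cols] = 1.0  # a repeated index simply writes 1 again
    return mask

idxs = np.array([[1, 3], [0, 2]])
dense, intermediate_size = mask_via_one_hot(idxs, max_values=5)
print(intermediate_size)  # batch_size * num_indexes * max_values = 20
print(np.array_equal(dense, mask_via_scatter(idxs, max_values=5)))  # True
```

In the symbolic API, the scatter step might map to `mx.sym.scatter_nd(data, indices, shape)`, with `indices` holding `rows` and `cols` stacked into shape [2, batch_size * num_indexes], `data` a vector of ones of matching length, and `shape=(batch_size, max_values)`. I have not verified this end to end. The MXNet docs warn that duplicate indices make `scatter_nd` non-deterministic with incorrect gradients; since every written value is 1, the forward mask should still be the same.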