I’m converting a chunk of numpy code to run on MXNet. I’m not sure how to convert this line without dropping back to numpy:
smallest_k = np.argpartition(vec, k)[0:k]
If you’re not familiar with it, np.partition does a partial sort on the vector: it just separates the numbers into two groups (partitions), one of size k and the other of size len(vec) - k, such that every value in the first group is less than or equal to every value in the second. np.argpartition does the same thing but gives you the indices into the vector rather than the values.
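To pin down exactly what a replacement has to reproduce, here is a small numpy sketch. The helper name `smallest_k_indices` and the sample data are made up for illustration. It checks the partition property and compares the result against a full `argsort`, which gives the same set of indices (when there are no ties) but costs O(n log n) instead of argpartition’s average O(n).

```python
import numpy as np

def smallest_k_indices(vec, k):
    """Hypothetical helper: indices of the k smallest entries of vec, in no particular order."""
    # argpartition moves the index of the k-th smallest value to position k;
    # the indices before it point at values <= that one, but are not sorted.
    return np.argpartition(vec, k)[:k]

vec = np.array([9.0, 1.0, 7.0, 3.0, 8.0, 2.0])
k = 3
idx = smallest_k_indices(vec, k)

# Partition property: every selected value is <= every value left out.
rest = np.setdiff1d(np.arange(len(vec)), idx)
assert vec[idx].max() <= vec[rest].min()

# Same index set as a full sort (no ties in this data).
assert set(idx.tolist()) == set(np.argsort(vec)[:k].tolist())

print(sorted(vec[idx].tolist()))  # [1.0, 2.0, 3.0]
```

Note that np.argpartition requires k < len(vec), and the returned indices are not ordered by value; if order matters, sort vec[idx] afterwards.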
Can I do this efficiently in MXNet (Symbol or Gluon) without writing a custom operator or dropping back to numpy? Advice appreciated.