MXNet Forum

NDArray API: functions similar to ix_() and cumsum()


#1

Hi!

I am trying to select elements from an NDArray by index, like this:

import mxnet as mx
import numpy as np

ix_valid = np.ix_(valid.asnumpy().astype(np.uint8) != 0)
vlabels = labels[mx.nd.array(ix_valid)]

where valid is an MXNet NDArray.

But this is slow because of the time spent converting to NumPy with asnumpy().
Is there a way to do this using only MXNet NDArray operations?

I am also interested in an MXNet equivalent of NumPy's cumsum:
numpy.cumsum(a, axis=None, dtype=None, out=None)


#2

You could use mx.nd.where: https://mxnet.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.where
For example, mx.nd.where(valid != 0, labels, mx.nd.zeros_like(labels)) returns an NDArray with the same shape as labels, in which each element is zero if the corresponding element of valid is zero and is taken from labels otherwise. Note that this masks the invalid entries rather than removing them. If you need the indices themselves, you would need a workaround.

A cumsum function is currently not supported, but there is already a feature request for it: https://github.com/apache/incubator-mxnet/issues/13001