How to select specific rows using a 0/1 mask?

As shown in the following code, I want to select the first and third rows, because their mask values are 1, and skip the second row, whose mask value is 0. However, a[b] does not give me what I want: it returns rows 1, 0 and 1 instead.
How can I get the result I want?

import mxnet as mx

a = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
b = mx.nd.array([1,0,1])
a[b]
[[4. 5. 6.]
 [1. 2. 3.]
 [4. 5. 6.]]
<NDArray 3x3 @cpu(0)>
# I want to get:
# [[1. 2. 3.]
#  [7. 8. 9.]]
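The output above matches NumPy's integer-array indexing: each value in b is read as a row index, so [1, 0, 1] gathers rows 1, 0 and 1. Only a boolean array is treated as a mask. The sketch below uses plain NumPy rather than MXNet, purely to illustrate the difference between the two kinds of indexing.

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = np.array([1, 0, 1])

# Integer array: each entry is a row index, so rows 1, 0, 1 are gathered.
gathered = a[b]

# Boolean array: each entry says whether to keep the row at that position.
masked = a[b.astype(bool)]

print(gathered)  # rows [4 5 6], [1 2 3], [4 5 6]
print(masked)    # rows [1 2 3], [7 8 9]
```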

Hi @zhoulukuan,

I believe this is what you’re trying to achieve here:

import mxnet as mx

a = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])

# b should be the indices you want to take.
b = mx.nd.array([0,2])

a.take(b, axis=0)

[[1. 2. 3.]
 [7. 8. 9.]]
<NDArray 2x3 @cpu(0)>
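Rather than hard-coding [0, 2], the indices can be derived from the 0/1 mask itself. A minimal sketch in plain NumPy, assuming it is acceptable to copy the mask to the host; in MXNet that would mean calling b.asnumpy() first and wrapping the resulting indices with mx.nd.array before passing them to take.

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = np.array([1, 0, 1])

# Turn the 0/1 mask into the positions of its non-zero entries: [0, 2].
idx = np.flatnonzero(mask)

# Gather those rows along axis 0, mirroring a.take(b, axis=0) in the reply.
rows = a.take(idx, axis=0)

print(idx)
print(rows)
```

This keeps the take-based answer but removes the manual step of reading the indices off the mask.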

I don't think this solves the problem the OP asked about. It reaches the same result, but only because the indices [0, 2] were worked out by hand from the mask, and that approach breaks down in some cases. For example, suppose I want to do this kind of masked selection in batches, where each sample keeps a different number of rows:

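import mxnet as mx

x = mx.nd.array([ [[1,2,3], [4,5,6], [7,8,9]], [[-1,-2,-3], [-4,-5,-6], [-7,-8,-9]] ], ctx=mx.cpu()) # shape (2, 3, 3)
y = mx.nd.array( [[1,0,1], [0,1,0]], ctx=mx.cpu() ) # per-sample row mask, shape (2, 3)
# The index approach fails here: the samples keep different numbers of rows,
# so the ragged index list below cannot be packed into a single NDArray.
# z = mx.nd.array( [[0, 2], [1]])
# mx.nd.take(x, z)
# desired output:
# [ [1, 2, 3], [7, 8, 9], [-4, -5, -6] ]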

The boolean_mask operator in mx.nd.contrib might be helpful:
https://mxnet.apache.org/api/python/docs/api/ndarray/contrib/index.html#mxnet.ndarray.contrib.boolean_mask
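Per the linked documentation, boolean_mask takes an n-d data array and a 1-d index array and keeps the rows of data whose entry in index is non-zero. The sketch below emulates those semantics in plain NumPy (boolean_mask_np is a hypothetical helper, not the MXNet implementation) and applies it to the batched example by first merging the batch and row axes, so each row gets exactly one mask entry. In MXNet the equivalent call should look like mx.nd.contrib.boolean_mask(x.reshape((-1, 3)), y.reshape((-1,))), but that is an untested assumption here.

```python
import numpy as np

def boolean_mask_np(data, index, axis=0):
    """Hypothetical NumPy stand-in for boolean_mask: keep the slices of
    `data` along `axis` whose entry in the 1-D `index` is non-zero."""
    keep = np.asarray(index) != 0
    if keep.ndim != 1 or keep.shape[0] != data.shape[axis]:
        raise ValueError("index must be 1-D and match data.shape[axis]")
    return np.compress(keep, data, axis=axis)

# Batched example from the thread, as float32 like mx.nd.array produces.
x = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
              [[-1, -2, -3], [-4, -5, -6], [-7, -8, -9]]], dtype=np.float32)  # (2, 3, 3)
y = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.float32)                       # (2, 3)

# Merge batch and row axes so the mask becomes 1-D, one entry per row.
flat_rows = x.reshape(-1, x.shape[-1])  # (6, 3)
flat_mask = y.reshape(-1)               # (6,)

selected = boolean_mask_np(flat_rows, flat_mask)
print(selected)
```

Note that the flattened result no longer records which sample each selected row came from.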


@ThomasDelteil That works. Thanks for your help.