How to choose specific rows according to a 0/1 mask?

As shown in the following code, I want to choose the first and third rows because their corresponding mask values are "1", while the second row's mask value is "0". However, a[b] does not give me what I want.
How to achieve the effect I want?

import mxnet as mx

a = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
b = mx.nd.array([1,0,1])
a[b]
[[4. 5. 6.]
 [1. 2. 3.]
 [4. 5. 6.]]
<NDArray 3x3 @cpu(0)>
# I want to get:
# [[1. 2. 3.]
#  [7. 8. 9.]]
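The output above is consistent with a[b] treating b as a list of row indices (1, 0, 1) rather than as a mask, so it returns row 1, then row 0, then row 1 again. The minimal sketch below illustrates the same distinction in NumPy (used here only as an analogue; it is not MXNet code), where an integer array selects rows by position and a boolean array selects rows by mask.

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = np.array([1, 0, 1])

# Integer array: interpreted as row indices 1, 0, 1 (fancy indexing),
# which reproduces the "unexpected" result from the question.
by_index = a[mask]

# Boolean array: interpreted as a mask, keeping only rows where it is True.
by_mask = a[mask.astype(bool)]

print(by_index)
print(by_mask)
```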

Hi @zhoulukuan,

I believe this is what you’re trying to achieve here:

import mxnet as mx

a = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])

# b holds the row indices you want to take (not a 0/1 mask).
b = mx.nd.array([0,2])

a.take(b, axis=0)

[[1. 2. 3.]
 [7. 8. 9.]]
<NDArray 2x3 @cpu(0)>
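If you already have a 0/1 mask, you can derive the indices from it instead of writing them by hand. The sketch below shows the idea in NumPy, assuming the mask is small enough to bring to the host; `mask_to_indices` is a hypothetical helper name. In MXNet, the same result could be obtained by passing `b.asnumpy()` to `np.flatnonzero` and wrapping the result with `mx.nd.array(...)` before calling `a.take(..., axis=0)`.

```python
import numpy as np

def mask_to_indices(mask):
    """Hypothetical helper: positions where a 0/1 mask is nonzero."""
    return np.flatnonzero(np.asarray(mask))

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = np.array([1, 0, 1])

idx = mask_to_indices(mask)      # array([0, 2])
rows = a.take(idx, axis=0)       # keep rows 0 and 2

print(rows)
```

Converting the mask to indices first keeps the selection step identical to the `take` call above, so only the small index array has to be computed from the mask.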