As shown in the following code, I want to select the first and third rows of `a`, because their corresponding entries in `b` are 1, while the second row's entry is 0. However, `a[b]` does not give me what I want: it seems to treat the values of `b` as row indices rather than as a mask.
How can I select rows using a 0/1 mask like this?
```python
import mxnet as mx

a = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = mx.nd.array([1, 0, 1])
print(a[b])
# Actual output:
# [[4. 5. 6.]
#  [1. 2. 3.]
#  [4. 5. 6.]]
# <NDArray 3x3 @cpu(0)>

# I want to get:
# [[1. 2. 3.]
#  [7. 8. 9.]]
```