Simple NDArray/Symbol indexing

I have a very simple use case that seems to be far too difficult in MXNet.

This is my input:

import mxnet as mx

x = mx.nd.array([[ 1.,  2.,  3.,  4.],
                 [ 5.,  6.,  7.,  8.],
                 [ 9., 10., 11., 12.]])

I want to select/pick/take/whatever a few items from each row. For example: from the first row, the first two elements; from the second row, the second and third elements; and from the third row, again the first two elements.

indices = mx.nd.array([[0,1],[1,2], [0,1]])

Pick complains that my index has too many dimensions, which makes no sense to me.
Take returns entire rows instead of individual elements.

I expect the result to be:

y = [[1, 2], [6, 7], [9, 10]]

So how does one go about this?

I have two ways of doing it.

One way is to use take to select the rows individually and then use pick to extract the elements at the given indices.

So you have something like:

take_indices = mx.nd.array([[0, 0], [1, 1], [2, 2]])  # each row index is repeated twice because indices selects two columns per row; in general, repeat it once per column of indices

indices = mx.nd.array([[0, 1], [1, 2], [0, 1]])  # same as before

y = x.take(take_indices).pick(indices)  # [[1, 2], [6, 7], [9, 10]]
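To see what each step does, here is a NumPy model of the same two operations. NumPy is swapped in only so the shapes are easy to check without MXNet: `np.take(..., axis=0)` stands in for `take`, and `np.take_along_axis` on the last axis stands in for `pick`. It is a sketch of the semantics, not MXNet's implementation, and unlike the MXNet example it needs integer index arrays.

```python
import numpy as np

x = np.array([[1., 2., 3., 4.],
              [5., 6., 7., 8.],
              [9., 10., 11., 12.]])
indices = np.array([[0, 1], [1, 2], [0, 1]])       # NumPy needs integer indices
take_indices = np.array([[0, 0], [1, 1], [2, 2]])

# Step 1, like x.take(take_indices): each entry is replaced by the whole
# row it names, so the shape grows from (3, 2) to (3, 2, 4).
rows = np.take(x, take_indices, axis=0)

# Step 2, like .pick(indices) on the last axis: for each (i, j), keep
# element indices[i, j] of rows[i, j]; the last axis disappears.
y = np.take_along_axis(rows, indices[..., None], axis=-1)[..., 0]

# Integer-array indexing does both steps at once in NumPy.
y_direct = x[take_indices, indices]

print(rows.shape)  # (3, 2, 4)
print(y)
```

The intermediate `(3, 2, 4)` array is why take alone seems to "pick entire rows": it only selects along axis 0, and pick then narrows each selected row down to one element.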

Note that this is also quite general. For example, if you change indices to select three elements per row:

indices = mx.nd.array([[0, 1, 2], [1, 2, 3], [0, 1, 2]])

then take_indices becomes

take_indices = mx.nd.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])

and

x.take(take_indices).pick(indices)

[[ 1.  2.  3.]
 [ 6.  7.  8.]
 [ 9. 10. 11.]]
<NDArray 3x3 @cpu(0)>
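Writing take_indices by hand gets tedious as indices grows. Since row i is just the value i repeated once per column, it can be derived from `indices.shape`. The sketch below uses NumPy and a hypothetical helper name, `row_ids_for`; in MXNet the analogous expression would presumably be `mx.nd.arange(indices.shape[0]).reshape((-1, 1)).broadcast_to(indices.shape)`, though that is untested here.

```python
import numpy as np

def row_ids_for(indices):
    """Hypothetical helper: build take_indices for a (rows, k) index array.

    Row i is filled with the value i, repeated k times, which is the
    hand-written take_indices pattern used above.
    """
    n_rows = indices.shape[0]
    return np.broadcast_to(np.arange(n_rows)[:, None], indices.shape)

x = np.arange(1., 13.).reshape(3, 4)   # same values as x in the question
indices = np.array([[0, 1, 2], [1, 2, 3], [0, 1, 2]])

take_indices = row_ids_for(indices)                                # [[0,0,0],[1,1,1],[2,2,2]]
rows = np.take(x, take_indices, axis=0)                            # like take
y = np.take_along_axis(rows, indices[..., None], axis=-1)[..., 0]  # like pick
print(y)
```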

The other way is uglier: it calls pick once for each column of indices and then concatenates the results, i.e. (with the original two-column indices)

mx.nd.concat(x.pick(indices[:, 0]).reshape(-1, 1), x.pick(indices[:, 1]).reshape(-1, 1))

[[ 1.  2.]
 [ 6.  7.]
 [ 9. 10.]]
<NDArray 3x2 @cpu(0)>
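The concat version can be written for any number of columns with a loop over the columns of indices. Here is a NumPy sketch in which `x[row_ids, indices[:, j]]` plays the role of `x.pick(indices[:, j])`; the MXNet form would presumably be `mx.nd.concat(*[x.pick(indices[:, j]).reshape(-1, 1) for j in range(indices.shape[1])], dim=1)`, untested here. Either way it issues one pick call per column.

```python
import numpy as np

x = np.arange(1., 13.).reshape(3, 4)              # same values as x above
indices = np.array([[0, 1], [1, 2], [0, 1]])
row_ids = np.arange(x.shape[0])                   # [0, 1, 2]

# One pick per column: indices[:, j] names one element per row, which is
# what x.pick(indices[:, j]) computes; reshape each result to a (3, 1) column.
columns = [x[row_ids, indices[:, j]].reshape(-1, 1)
           for j in range(indices.shape[1])]

# Put the columns side by side, as mx.nd.concat(..., dim=1) does.
y = np.concatenate(columns, axis=1)
print(y)
```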

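The problem with pick is that it indexes implicitly by position: the i-th entry of the index selects an element from the i-th row, so the index must have the shape of the data with the picked axis removed. If your original array is 2D, you can only use a 1D pick index.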
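That shape rule can be made concrete with a hypothetical NumPy model of pick along the last axis. It encodes the rule described above (index shape equals data shape with the picked axis removed) and ignores pick's `mode` and `keepdims` options; it is an illustration, not MXNet's code.

```python
import numpy as np

def pick_last_axis(data, index):
    """Hypothetical NumPy model of pick(data, index, axis=-1).

    index must have data's shape with the last axis removed; each index
    entry then selects one element from the matching 1-D slice of data.
    """
    expected = data.shape[:-1]
    if index.shape != expected:
        raise ValueError(f"index shape {index.shape} should be {expected}")
    return np.take_along_axis(data, index[..., None], axis=-1)[..., 0]

x = np.arange(1., 13.).reshape(3, 4)

ok = pick_last_axis(x, np.array([0, 1, 0]))       # 1-D index on 2-D data: accepted

try:
    pick_last_axis(x, np.array([[0, 1], [1, 2], [0, 1]]))  # 2-D index on 2-D data
    rejected = False
except ValueError:
    rejected = True
```

This is why both workarounds reshape the problem: take adds a dimension to the data so a 2D index fits, while the concat version splits the 2D index into 1D columns.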

I have an answer here that also explains the difference between pick and take:

What if I need a different number of items from each row?
For example:

indices = mx.nd.array([[0,1], [1,], [1, 2]])

Could you give me some info on how I could take the rows dynamically here?

Thanks.