I have a very simple use case that seems to be far too difficult in MXNet.
This is my input:
import mxnet as mx
x = mx.nd.array([[ 1.,  2.,  3.,  4.],
                 [ 5.,  6.,  7.,  8.],
                 [ 9., 10., 11., 12.]])
I want to select (pick, take, whatever) specific items on each row: the first two elements of the first row, the second and third elements of the second row, and again the first two elements of the third row. These are the column indices for each row:
indices = mx.nd.array([[0, 1], [1, 2], [0, 1]])
The pick operator complains that my index array has too many dimensions, which makes no sense to me.
The take operator returns entire rows instead of individual elements.
I expect the result to be:
y = [[1, 2], [6, 7], [9, 10]]
So how does one go about this?