What is the MXNet/Gluon equivalent of PyTorch's index_fill_()?


#1

I want to translate some PyTorch code to MXNet/Gluon, but I cannot find the corresponding function in MXNet/Gluon. Any advice would be appreciated. Thanks!


#2

In Gluon, things are easier thanks to advanced indexing support. So the following PyTorch code:

```python
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
index = torch.tensor([0, 2])
x.index_fill_(1, index, -1)  # in place: columns 0 and 2 (dim=1) become -1
```

would translate to the following in Gluon:

```python
from mxnet import nd
x = nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
x[:, [0, 2]] = -1  # advanced indexing: columns 0 and 2 of every row become -1
```
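If the fill dimension is not known in advance, the same advanced-indexing assignment can be built programmatically, mirroring the `index_fill_(dim, index, value)` signature. The sketch below uses a hypothetical helper `index_fill` and is demonstrated on a NumPy array so it runs without MXNet installed. It assumes, without verifying it here, that an MXNet NDArray accepts the same tuple key (full slices plus one index list), as it does for `x[:, [0, 2]]`.

```python
import numpy as np


def index_fill(x, dim, index, value):
    """Hypothetical helper mimicking torch.Tensor.index_fill_ (modifies x in place)."""
    # Take every position along all axes...
    key = [slice(None)] * x.ndim
    # ...except along `dim`, where only the requested positions are selected.
    key[dim] = list(index)
    x[tuple(key)] = value
    return x


x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
index_fill(x, 1, [0, 2], -1)
print(x)
```

For `dim=1` the key becomes `(slice(None), [0, 2])`, which is exactly `x[:, [0, 2]]`; for `dim=0` it selects whole rows instead.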

#3

Got it, thank you very much!