Equivalent to PyTorch scatter function


Is there an MXNet equivalent to the PyTorch scatter function (http://pytorch.org/docs/0.3.0/tensors.html?highlight=scatter#torch.Tensor.scatter_)?

Given a (4,4) ndarray of zeros and the indices [[0,1],[1,0],[2,2],[3,3]] of the elements I want to set to 1, is it possible to update all of them at once so that the ndarray becomes the one below?

[[0, 1, 0, 0],
 [1, 0, 0, 0],
 [0, 0, 1, 0],
 [0, 0, 0, 1]]


The following code does this with advanced indexing:

import mxnet as mx

# Start from a 4x4 matrix of zeros.
x = mx.nd.zeros((4, 4), dtype='int32')

# Row index i is paired with column index cols[i]: (0,1), (1,0), (2,2), (3,3).
rows = mx.nd.arange(4).reshape((4, 1))
cols = mx.nd.array([1, 0, 2, 3]).reshape((4, 1))

# Set all selected elements to 1 in a single assignment.
x[rows, cols] = 1

Since the ndarray is two-dimensional, we need to specify both the rows and the corresponding columns we want to change. In this example the rows are [[0],[1],[2],[3]] and the columns are [[1],[0],[2],[3]]; the i-th row index is paired with the i-th column index.
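If the indices arrive as (row, column) pairs, as in the question, they first have to be split into a list of row indices and a list of column indices. The sketch below uses plain Python lists so it runs without MXNet; `pairs_to_rows_cols` and `assign_pairs` are hypothetical helper names, and the loop only mimics what `x[rows, cols] = 1` does element by element.

```python
def pairs_to_rows_cols(pairs):
    """Split [(r0, c0), (r1, c1), ...] into ([r0, r1, ...], [c0, c1, ...])."""
    rows, cols = zip(*pairs)
    return list(rows), list(cols)


def assign_pairs(matrix, rows, cols, value):
    """Set matrix[rows[i]][cols[i]] = value for every i, in place.

    This mirrors the pairing used by index-array assignment:
    the i-th row index is matched with the i-th column index.
    """
    for r, c in zip(rows, cols):
        matrix[r][c] = value
    return matrix


pairs = [[0, 1], [1, 0], [2, 2], [3, 3]]
rows, cols = pairs_to_rows_cols(pairs)
x = [[0] * 4 for _ in range(4)]
assign_pairs(x, rows, cols, 1)
print(rows, cols)  # [0, 1, 2, 3] [1, 0, 2, 3]
print(x)
```

The resulting `rows` and `cols` are the flattened versions of the (4, 1) index arrays used in the MXNet snippet.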


mx.nd.scatter_nd() in MXNet achieves the same result, although the API is quite different. For your particular example, this code will do the trick:

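mx.nd.scatter_nd(mx.nd.ones((4,)), mx.nd.array([[0, 1, 2, 3], [1, 0, 2, 3]]), (4, 4))

The first argument holds the values to scatter, the second holds the row indices (first row) and the column indices (second row), and the third is the shape of the output.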
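To make the argument layout concrete, here is a plain-Python sketch of the `scatter_nd` semantics for the two-dimensional case, as I read them from the MXNet documentation: column `i` of `indices` gives the coordinates where `data[i]` lands, and every other element of the freshly created output is zero. `scatter_nd_2d` is a hypothetical name, not an MXNet function, and the sketch ignores the general N-dimensional case.

```python
def scatter_nd_2d(data, indices, shape):
    """Reference sketch of scatter_nd for a 2-D output.

    data:    list of K values
    indices: [[r0, ..., rK-1], [c0, ..., cK-1]], i.e. shape (2, K)
    shape:   (n_rows, n_cols) of the new, zero-filled output
    """
    n_rows, n_cols = shape
    out = [[0] * n_cols for _ in range(n_rows)]  # always a new array
    row_idx, col_idx = indices
    for value, r, c in zip(data, row_idx, col_idx):
        out[r][c] = value
    return out


# The question's (row, column) pairs, transposed into a (2, K) layout.
result = scatter_nd_2d([1, 1, 1, 1], [[0, 1, 2, 3], [1, 0, 2, 3]], (4, 4))
for row in result:
    print(row)
```

The real MXNet call returns a float32 NDArray here, because `mx.nd.ones` defaults to float32; the sketch just uses Python integers.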

A few differences from PyTorch to keep in mind:

  • This function doesn’t scatter values onto an existing array. It creates a new one filled with zeros and scatters the values onto it.
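  • You always index from dimension 0 onward: indices has shape (M, ...) and addresses the first M dimensions of the output, so there is no option to index dimensions M to N with M > 0. PyTorch, by contrast, lets you index only one dimension, but that dimension can be any of them (chosen with the dim argument), not just 0.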
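For comparison, the PyTorch side can be sketched the same way. The documented rule for a 2-D tensor is `self[index[i][j]][j] = src[i][j]` when `dim == 0` and `self[i][index[i][j]] = src[i][j]` when `dim == 1`, and it writes into an existing tensor instead of allocating a new one. `scatter_2d_` is a hypothetical helper operating on nested lists, not the PyTorch API; it shows that the same target matrix can be reached by scattering along either dimension.

```python
def scatter_2d_(target, dim, index, src):
    """In-place sketch of PyTorch's Tensor.scatter_ for 2-D inputs.

    Visits every position (i, j) of `index`; only the coordinate
    along `dim` is replaced by the index value index[i][j].
    """
    for i, index_row in enumerate(index):
        for j, k in enumerate(index_row):
            if dim == 0:
                target[k][j] = src[i][j]
            elif dim == 1:
                target[i][k] = src[i][j]
            else:
                raise ValueError("dim must be 0 or 1 for a 2-D target")
    return target


# Scatter along dim 1: one column index per row.
x = [[0] * 4 for _ in range(4)]
scatter_2d_(x, 1, [[1], [0], [2], [3]], [[1], [1], [1], [1]])

# Scatter along dim 0: one row index per column.
y = [[0] * 4 for _ in range(4)]
scatter_2d_(y, 0, [[1, 0, 2, 3]], [[1, 1, 1, 1]])

print(x)
print(y)
```

In recent PyTorch versions the first call corresponds roughly to `torch.zeros(4, 4).scatter_(1, torch.tensor([[1], [0], [2], [3]]), 1.0)`.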