Gradient fetching


#1

I am trying to modify how a framework stores gradients in memory and retrieves them during forward or backward propagation, so I would like to know how MXNet actually fetches the gradients from memory and distributes them to the nodes of the network.


#2

In MXNet Gluon, the gradient is stored on the parameters. If you do everything from scratch, you can allocate space for the gradient by calling .attach_grad() on the NDArray holding your parameter. Have a look at this tutorial here, where linear regression is implemented from scratch; the same example using the built-in layers is here.
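To make that storage model concrete, here is a toy NumPy sketch (not MXNet's actual implementation) of what attaching a gradient amounts to: each parameter owns a second buffer of the same shape, the backward pass writes into it, and the update later reads from it. The class `ToyParam` and its `accumulate` method are hypothetical; the `'write'`/`'add'` modes only mimic the `grad_req` option that MXNet's `attach_grad` accepts.

```python
import numpy as np

class ToyParam:
    """Hypothetical parameter holder: data plus a same-shaped gradient buffer."""

    def __init__(self, value, grad_req="write"):
        self.data = np.asarray(value, dtype=np.float64)
        # Like attach_grad: allocate gradient storage once, same shape as the data.
        self.grad = np.zeros_like(self.data)
        self.grad_req = grad_req

    def accumulate(self, g):
        # 'write' overwrites the buffer on each backward pass, 'add' sums into it.
        if self.grad_req == "write":
            self.grad[...] = g
        elif self.grad_req == "add":
            self.grad += g
        else:
            raise ValueError(f"unknown grad_req: {self.grad_req}")

# Linear regression y_hat = X @ w + b with squared loss L = sum((y_hat - y)**2) / 2
X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = np.array([1.0, 2.0, 3.0])
w = ToyParam(np.zeros(2))
b = ToyParam(0.0)

residual = X @ w.data + b.data - y      # forward pass
w.accumulate(X.T @ residual)            # dL/dw is stored on w
b.accumulate(residual.sum())            # dL/db is stored on b
# w.grad is now [-22, -28] and b.grad is -6
```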

You almost never need to do this manually, though, because it is taken care of for you when you use the built-in layers such as Conv2D or Dense. The gradients are calculated by automatic differentiation (autograd) for the operations run inside the autograd.record() scope. You then call .backward() on your loss, and the gradients of the loss with respect to each parameter are computed and written into that parameter's gradient buffer.
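The idea behind a recording scope plus `.backward()` can be illustrated with a deliberately tiny, scalar-only reverse-mode differentiator in plain Python. Everything in it (`Var`, `record`, the module-level tape) is hypothetical and far simpler than MXNet's autograd; it only shows the pattern: operations executed inside the scope are logged, and `backward()` replays that log in reverse, applying the chain rule and storing each result in the variable's own `grad` slot.

```python
from contextlib import contextmanager

_tape = []          # (output, [(input, local_derivative), ...]) in execution order
_recording = False

@contextmanager
def record():
    """Hypothetical analogue of autograd.record(): log ops executed inside the scope."""
    global _recording
    _tape.clear()
    _recording = True
    try:
        yield
    finally:
        _recording = False   # the tape is kept so backward() can replay it

class Var:
    """Hypothetical scalar that carries its own gradient slot."""

    def __init__(self, value):
        self.value = float(value)
        self.grad = 0.0

    def _make(self, value, parents):
        out = Var(value)
        if _recording:
            _tape.append((out, parents))
        return out

    def __add__(self, other):
        return self._make(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return self._make(self.value * other.value,
                          [(self, other.value), (other, self.value)])

    def backward(self):
        # Reset every gradient slot touched by the tape.
        for out, parents in _tape:
            out.grad = 0.0
            for p, _ in parents:
                p.grad = 0.0
        # Seed d(self)/d(self) = 1 and apply the chain rule in reverse order.
        self.grad = 1.0
        for out, parents in reversed(_tape):
            for p, local in parents:
                p.grad += local * out.grad

x, w, b = Var(3.0), Var(2.0), Var(1.0)
with record():
    y = w * x + b        # y = 7
    loss = y * y         # loss = 49
loss.backward()
# w.grad = 2*y*x = 42, b.grad = 2*y = 14, x.grad = 2*y*w = 28
```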

You can then apply the gradients to your parameters yourself, for example like this for the basic SGD algorithm:

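# params: NDArrays on which .attach_grad() was called; lr: the learning rate
for param in params:
    param[:] = param - lr * param.grad  # in-place update reuses the same memory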
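The update loop above can be exercised end to end with a NumPy stand-in. Here each parameter is a hypothetical `Param` object whose `data` and `grad` arrays play the role of the NDArray and its attached gradient. The gradients are computed by hand for a linear model with loss `mean(err**2) / 2`, and the loop uses the full batch, so it is the deterministic special case of SGD. As in the snippet, the update writes into `data` in place.

```python
import numpy as np

class Param:
    """Hypothetical stand-in for an NDArray with an attached gradient buffer."""
    def __init__(self, value):
        self.data = np.array(value, dtype=np.float64)
        self.grad = np.zeros_like(self.data)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
true_w, true_b = np.array([2.0, -3.4]), 4.2
y = X @ true_w + true_b                     # noise-free targets, so an exact fit exists

w, b = Param(np.zeros(2)), Param(np.zeros(1))
params, lr = [w, b], 0.1

for epoch in range(200):
    err = X @ w.data + b.data - y           # forward pass
    # backward pass for loss = mean(err**2) / 2, written into the grad buffers
    w.grad[:] = X.T @ err / len(y)
    b.grad[:] = err.mean()
    # the update rule from the snippet, applied in place to each parameter
    for param in params:
        param.data[:] = param.data - lr * param.grad
```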

In most cases, however, you would use a Trainer, which does this for you on the set of parameters assigned to it. The Trainer takes an optimizer algorithm, its hyperparameters, and the set of parameters to update. After each batch, you call the Trainer and ask it to update your weights. The code looks like this:

from mxnet import gluon

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.0001})
...
    # inside the training loop, after loss.backward()
    trainer.step(batch_size)
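What `trainer.step(batch_size)` does for plain SGD can be sketched with a hypothetical, heavily simplified `MiniTrainer` (not Gluon's code). The one detail taken from Gluon is that `Trainer.step` documents normalizing the gradient by `1/batch_size`, which matters because calling `backward()` on a per-sample loss sums the gradients over the batch.

```python
import numpy as np

class Param:
    """Hypothetical parameter: a value plus its gradient buffer."""
    def __init__(self, value):
        self.data = np.array(value, dtype=np.float64)
        self.grad = np.zeros_like(self.data)

class MiniTrainer:
    """Hypothetical, heavily simplified analogue of gluon.Trainer with plain SGD."""
    def __init__(self, params, learning_rate):
        self.params = list(params)
        self.lr = learning_rate

    def step(self, batch_size):
        # The gradients were summed over the batch during backward,
        # so rescale them by 1/batch_size before the SGD update.
        for p in self.params:
            p.data -= self.lr * p.grad / batch_size

w = Param([1.0, 1.0])
trainer = MiniTrainer([w], learning_rate=0.5)
w.grad[:] = [4.0, -8.0]      # pretend backward summed gradients over a batch of 4
trainer.step(batch_size=4)
# w.data is now [0.5, 2.0]
```

Keeping the normalization inside `step` means the loss can stay a per-sample sum while the effective step size does not grow with the batch size.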

Read the tutorials in MXNet the straight dope; they will give you a better understanding of how the framework works. Alternatively, you can get started with the 60-minute intro course here.

If you need more fine-grained control over the gradient computation itself, you can use a custom operator, which lets you define the backward pass as well. Have a look at this CustomOp here, which is used for CNN activation visualization.
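In MXNet, a custom operator is written by subclassing `mx.operator.CustomOp`, implementing both `forward` and `backward`, and registering a matching `CustomOpProp`. The NumPy sketch below only mirrors that forward/backward pairing with a hypothetical `GuidedReLU` class, so it runs without MXNet. Its `forward` is an ordinary ReLU, while its `backward` applies the guided-backpropagation rule often used when visualizing CNN activations: a gradient passes only where both the input and the incoming gradient are positive.

```python
import numpy as np

class GuidedReLU:
    """Hypothetical op with a hand-written backward pass (guided backpropagation)."""

    def forward(self, x):
        self.x = x                       # keep the input for the backward pass
        return np.maximum(x, 0.0)

    def backward(self, out_grad):
        # A standard ReLU backward would be out_grad * (x > 0);
        # guided backprop additionally drops negative incoming gradients.
        return out_grad * (self.x > 0) * (out_grad > 0)

op = GuidedReLU()
x = np.array([-1.0, 2.0, 3.0, -4.0])
y = op.forward(x)                        # [0, 2, 3, 0]
g = op.backward(np.array([5.0, -6.0, 7.0, 8.0]))
# g is [0, 0, 7, 0]: only the third position has x > 0 and a positive gradient
```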


#3

Would it be the same if I use the Module API and call the backward function of the Module class?