I am running into an issue when trying to compute multiple gradients. The neural network is a classifier. I want to compute the gradient of each of the (class) output nodes w.r.t. the input of the network. Here is the full code example:
from mxnet import autograd, init, nd
from mxnet.gluon import nn
import numpy as np

nb_classes = 10

def class_gradient(model, x, nb_classes, label=None):
    x_ = nd.array(x)
    if label is not None:
        # Gradient of a single class output w.r.t. the input
        x_.attach_grad()
        with autograd.record(train_mode=False):
            preds = model(x_)
        preds[:, label].backward(retain_graph=True, train_mode=False)
        grads = x_.grad.asnumpy()
    else:
        # Gradient of every class output w.r.t. the input
        grads = []
        for i in range(nb_classes):
            x_.attach_grad()
            with autograd.record(train_mode=False):
                preds = model(x_)
            preds[:, i].backward(retain_graph=True, train_mode=False)
            grads.append(x_.grad.asnumpy())
    return grads

# Create a simple CNN
net = nn.Sequential()
with net.name_scope():
    net.add(
        nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Flatten(),
        nn.Dense(120, activation="relu"),
        nn.Dense(84, activation="relu"),
        nn.Dense(nb_classes)
    )
net.initialize(init=init.Xavier())

# Random data in the shape of a small MNIST sample
data = np.random.rand(10, 1, 28, 28)
grads = class_gradient(net, data, nb_classes=nb_classes)
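To make the intent concrete, this is the quantity I am after, computed analytically in plain NumPy for a toy linear model (all names below are illustrative, not from my real code): for preds = x @ W, the gradient of output i with respect to the input x is just column i of W.

```python
import numpy as np

# Toy linear "classifier" preds = x @ W: the gradient of output i w.r.t.
# the input x is column i of W. Names here are illustrative only.
rng = np.random.RandomState(0)
n_features, nb_classes = 4, 3
W = rng.rand(n_features, nb_classes)
x = rng.rand(n_features)

# Analytic per-class input gradients -- the quantity class_gradient()
# is meant to return via autograd for the CNN above.
grads = [W[:, i] for i in range(nb_classes)]

# Finite-difference sanity check of the same quantity
eps = 1e-6
for i in range(nb_classes):
    numeric = np.array([(np.dot(x + eps * np.eye(n_features)[k], W)[i]
                         - np.dot(x, W)[i]) / eps
                        for k in range(n_features)])
    assert np.allclose(numeric, grads[i], atol=1e-4)
```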
I am getting the following error:
File ".../mxnet/base.py", line 149, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [18:04:27] src/imperative/imperative.cc:373: Check failed: !AGInfo::IsNone(*i)
Cannot differentiate node because it is not in a computational graph. You need to set is_recording
to true or use autograd.record() to save computational graphs for backward. If you want to
differentiate the same graph twice, you need to pass retain_graph=True to backward.
I am using Python 3.6, mxnet 1.2.0, and numpy 1.14.5. I have tried a few workarounds and code variants, but the error persists. Any help in solving this would be much appreciated.