I am running into an issue when trying to compute multiple gradients. The network is a classifier, and I want to compute the gradient of each class output node with respect to the network input. Here is the full code example:
```python
from mxnet import autograd, init, nd
from mxnet.gluon import nn
import numpy as np

nb_classes = 10

def class_gradient(model, x, nb_classes, label=None):
    x_ = nd.array(x)
    if label is not None:
        x_.attach_grad()
        with autograd.record(train_mode=False):
            preds = model(x_)
        preds[:, label].backward(retain_graph=True, train_mode=False)
        grads = x_.grad.asnumpy()
    else:
        grads = []
        for i in range(nb_classes):
            x_.attach_grad()
            with autograd.record(train_mode=False):
                preds = model(x_)
            preds[:, i].backward(retain_graph=True, train_mode=False)
            grads.append(x_.grad.asnumpy())
    return grads

# Create a simple CNN
net = nn.Sequential()
with net.name_scope():
    net.add(
        nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Flatten(),
        nn.Dense(120, activation="relu"),
        nn.Dense(84, activation="relu"),
        nn.Dense(nb_classes)
    )
net.initialize(init=init.Xavier())

# Random data in the shape of a small MNIST sample
data = np.random.rand(10, 1, 28, 28)
grads = class_gradient(net, data, nb_classes=nb_classes)
```
I am getting the following error:
```
File ".../mxnet/base.py", line 149, in check_call
  raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [18:04:27] src/imperative/imperative.cc:373: Check failed: !AGInfo::IsNone(*i) Cannot differentiate node because it is not in a computational graph. You need to set is_recording to true or use autograd.record() to save computational graphs for backward. If you want to differentiate the same graph twice, you need to pass retain_graph=True to backward.
```
I'm using Python 3.6, MXNet 1.2.0 and NumPy 1.14.5. I have tried several workarounds and code variants, but the error persists. Any help solving this would be much appreciated.