Differentiating specific softmax output label with respect to input image


#1

Given a generic pretrained image classification model with softmax output and N classes, I want to compute the gradient of softmax output j (0 <= j < N) with respect to the input image pixel values. My approach so far has been the following:

import numpy as np
import mxnet as mx

sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch) #load pretrained model
all_layers = sym.get_internals()
net = all_layers['fc_output'] #Include all layers up to but not including the SoftmaxOutput layer. I don't want the cross entropy loss function in the symbol, just the softmax outputs
net = mx.symbol.softmax(data = net, name='softmax_label') #add back softmax activation function to last fully connected layer output without adding cross entropy loss
my_model = mx.mod.Module(symbol=net, context = mx.cpu())
my_model.bind(data_shapes = [('data', (1,3,250,250))], inputs_need_grad = True, for_training = True) #input is 250x250 RGB image
my_model.set_params(arg_params, aux_params, allow_missing = False) #set weights to pretrained values

img = get_image('/home/ubuntu/data-2/test_image_normalized.jpg', 250, 'resize') #numpy array with shape (1,3,250,250)
x = mx.nd.array(img)
d = mx.io.DataBatch([x])
label_index = 1 #specifies which softmax output we want to compute the gradient of

my_model.forward(d)
my_model.backward(out_grad=mx.ndarray.one_hot(indices = mx.nd.array([label_index]), depth = N)) #this doesn't execute as written
w = my_model.get_input_grads()[0].asnumpy()
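For reference, the one-hot head gradient in the code above relies on a standard identity rather than on anything MXNet-specific. A backward pass seeded with a head gradient $g$ returns the vector-Jacobian product $g^\top (\partial p / \partial x)$, so seeding it with the one-hot vector $e_j$ selects the gradient of the single softmax output $p_j$. Writing $z$ for the `fc_output` logits and $p = \operatorname{softmax}(z)$:

$$
\frac{\partial p_j}{\partial x} = \sum_{k=0}^{N-1} (e_j)_k\,\frac{\partial p_k}{\partial x} = \sum_{m=0}^{N-1} \frac{\partial p_j}{\partial z_m}\,\frac{\partial z_m}{\partial x},
\qquad
\frac{\partial p_j}{\partial z_m} = p_j\,(\delta_{jm} - p_m).
$$

The head gradient therefore has to have the same shape as the softmax output, $(1, N)$ for a batch of one.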

Unfortunately, the backward() call fails and yields the following error message:

[18:40:59] /home/ubuntu/src/mxnet/dmlc-core/include/dmlc/./logging.h:308: [18:40:59] src/ndarray/ndarray.cc:348: Check failed: from.shape() == to->shape() operands shape mismatchfrom.shape = (1,) to.shape=(1,2)

Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN4dmlc15LogMessageFatalD1Ev+0x3c) [0x7fbf65fdef0c]
[bt] (1) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN5mxnet10CopyFromToERKNS_7NDArrayEPS0_i+0x546) [0x7fbf66cc12f6]
[bt] (2) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN5mxnet4exec13GraphExecutor8BackwardERKSt6vectorINS_7NDArrayESaIS3_EEb+0xb3) [0x7fbf67082173]
[bt] (3) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(MXExecutorBackwardEx+0x314) [0x7fbf6704b4f4]
[bt] (4) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(ffi_call_unix64+0x4c) [0x7fbf99f90e20]
[bt] (5) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(ffi_call+0x2eb) [0x7fbf99f9088b]
[bt] (6) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(_ctypes_callproc+0x49a) [0x7fbf99f8b01a]
[bt] (7) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(+0x9fcb) [0x7fbf99f7efcb]
[bt] (8) /usr/bin/python3(PyObject_Call+0x47) [0x5b7167]
[bt] (9) /usr/bin/python3(PyEval_EvalFrameEx+0x4f06) [0x528d06]

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-167-3a4eabe9e4b8> in <module>()
----> 1 my_model.backward(h)

/usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/module/module.py in backward(self, out_grads)
    611         """
    612         assert self.binded and self.params_initialized
--> 613         self._exec_group.backward(out_grads=out_grads)
    614
    615     def update(self):

/usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/module/executor_group.py in backward(self, out_grads)
    545                 else:
    546                     out_grads_slice.append(grad.copyto(self.contexts[i]))
--> 547             exec_.backward(out_grads=out_grads_slice)
    548
    549     def update_metric(self, eval_metric, labels):

/usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/executor.py in backward(self, out_grads, is_train)
    229             mx_uint(len(out_grads)),
    230             ndarray,
--> 231             ctypes.c_int(is_train)))
    232
    233     def set_monitor_callback(self, callback):

/usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/base.py in check_call(ret)
    127     """
    128     if ret != 0:
--> 129         raise MXNetError(py_str(_LIB.MXGetLastError()))
    130
    131 if sys.version_info[0] < 3:

MXNetError: [18:40:59] src/ndarray/ndarray.cc:348: Check failed: from.shape() == to->shape() operands shape mismatchfrom.shape = (1,) to.shape=(1,2)

Any help correcting my approach would be much appreciated.


#2

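How about the following for backward? Note that the keyword is `out_grads` (plural), and the head gradient is wrapped in a list holding one NDArray per output of the symbol, each shaped like that output, here `(1, N)`: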

my_model.backward(out_grads=[mx.ndarray.one_hot(indices = mx.nd.array([label_index]), depth = N)])
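To see why a one-hot head gradient yields the gradient of the j-th softmax output, here is a minimal NumPy sketch. It replaces the pretrained network with a hypothetical linear layer `W` (an assumption, not the asker's model) and performs by hand the two backward steps that MXNet would perform: the softmax vector-Jacobian product, then multiplication by the transposed weights. It then checks the result against central finite differences. The names `softmax`, `softmax_vjp`, `W` and `D` are illustrative only.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_vjp(p, head_grad):
    # Backward pass of softmax: given p = softmax(z) and dL/dp (head_grad),
    # return dL/dz = p * (head_grad - sum(head_grad * p)).
    return p * (head_grad - (head_grad * p).sum(axis=-1, keepdims=True))

rng = np.random.default_rng(0)
N, D = 4, 6                      # hypothetical: 4 classes, 6 "pixels"
W = rng.normal(size=(D, N))      # hypothetical linear layer standing in for the network
x = rng.normal(size=(1, D))      # one input, batch size 1
label_index = 1

z = x @ W                        # logits, shape (1, N)
p = softmax(z)                   # softmax outputs, shape (1, N)

head = np.eye(N)[[label_index]]  # one-hot head gradient, shape (1, N), same as p
grad_z = softmax_vjp(p, head)    # d p_j / d z, shape (1, N)
grad_x = grad_z @ W.T            # d p_j / d x, shape (1, D), same as x

# Sanity check: central finite differences of p_j with respect to each input entry.
eps = 1e-6
fd = np.zeros_like(x)
for i in range(D):
    xp, xm = x.copy(), x.copy()
    xp[0, i] += eps
    xm[0, i] -= eps
    fd[0, i] = (softmax(xp @ W)[0, label_index] - softmax(xm @ W)[0, label_index]) / (2 * eps)

print(np.allclose(grad_x, fd, atol=1e-6))
```

With the real model, MXNet's executor carries out these backward steps through every layer. The sketch only illustrates the arithmetic: the head gradient matches the output's shape, and the result matches the input's shape, which mirrors what `get_input_grads()` returns.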