MXNet crashes when executing backward() for a custom symbol


I am implementing a simple custom symbol, and MXNet crashes while running backward(). Here is the code:

import mxnet as mx

class GlobalAvgPool2D(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        y = mx.nd.mean(in_data[0], axis=0, exclude=True)
        self.assign(out_data[0], req[0], y)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        dims = in_grad[0].shape()
        m = 1
        for dim in dims[1:]: m *= dim
        self.assign(in_grad[0], req[0], in_grad[0]/float(m))

class GlobalAvgPool2DProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(GlobalAvgPool2DProp, self).__init__(True)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shapes):
        # return 3 lists representing inputs shapes, outputs shapes, and aux data shapes.
        return (in_shapes[0],), ([in_shapes[0][0]],), ()

    def create_operator(self, ctx, in_shapes, in_dtypes):
        #  create and return the CustomOp class.
        return GlobalAvgPool2D()

Here is the error message:
terminate called after throwing an instance of 'dmlc::Error'
what(): [20:21:18] src/operator/custom/ Check failed: reinterpret_cast(>callbacks[kCustomOpBackward])( ptrs.size(), const_cast<void**>(, const_cast<int*>(, reinterpret_cast<const int*>(, static_cast(ctx.is_train),>contexts[kCustomOpBackward])

Stack trace returned 7 entries:
[bt] (0) /home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/ [0x7f2098282938]
[bt] (1) /home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/ [0x7f2098282d48]
[bt] (2) /home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/ [0x7f20983b5c6a]
[bt] (3) /home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/ [0x7f20983a3246]
[bt] (4) /home/ubuntu/anaconda3/envs/mxnet_p27/bin/…/lib/ [0x7f2180ed0c5c]
[bt] (5) /lib/x86_64-linux-gnu/ [0x7f2181f126ba]
[bt] (6) /lib/x86_64-linux-gnu/ [0x7f218153841d]
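The failed check in this trace wraps the call from C++ into the Python backward() callback. As far as I can tell from MXNet's Python custom-operator wrapper (1.x sources), an exception raised inside backward() is caught, printed as a Python traceback (look for a line starting with `Error in CustomOp.backward:` just above the C++ message), and reported back as a failure, which trips this check. One likely culprit is the first line of backward(): `NDArray.shape` is a property that already returns a tuple, so `in_grad[0].shape()` tries to call a tuple. A minimal plain-Python illustration, with a made-up tuple standing in for the real shape:

```python
# Hypothetical stand-in for in_grad[0].shape: in MXNet, NDArray.shape is a
# property that already returns a plain tuple such as (N, C, H, W).
dims = (2, 3, 4, 4)

try:
    dims()  # what the original in_grad[0].shape() effectively does
except TypeError as err:
    message = str(err)
    print(message)  # 'tuple' object is not callable

# Reading the property without parentheses works: product of non-batch dims.
m = 1
for dim in dims[1:]:
    m *= dim
print(m)  # 48
```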


I am not sure where the problem is, but have you seen this class:

I think it does a similar thing, so you might not need a custom operator at all.
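If you do keep the custom op, note that even with `.shape` fixed, this backward() divides `in_grad[0]` (the buffer being written) instead of using `out_grad[0]`. For a mean over the m = C·H·W non-batch elements, every input element should receive its sample's output gradient divided by m. In MXNet that would presumably look something like `out_grad[0].reshape((-1, 1, 1, 1)).broadcast_to(dims) / float(m)` for 4-D input (untested). The sketch below checks that rule against finite differences; it assumes NumPy, and the helper names are hypothetical:

```python
import numpy as np


def global_avg_pool_forward(x):
    # Mean over every non-batch axis, mirroring
    # mx.nd.mean(x, axis=0, exclude=True): (N, C, H, W) -> (N,)
    return x.reshape(x.shape[0], -1).mean(axis=1)


def global_avg_pool_backward(x_shape, out_grad):
    # Each of the m = C*H*W inputs of sample n contributes 1/m to its mean,
    # so each one receives out_grad[n] / m.
    m = int(np.prod(x_shape[1:]))
    # Reshape (N,) -> (N, 1, 1, 1) so it broadcasts over the input shape.
    g = out_grad.reshape((-1,) + (1,) * (len(x_shape) - 1))
    return np.broadcast_to(g, x_shape) / float(m)


def numeric_grad(x, out_grad, eps=1e-5):
    # Central finite differences of L = sum(out_grad * forward(x)),
    # whose exact gradient with respect to x is the backward pass above.
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        plus = np.sum(out_grad * global_avg_pool_forward(x))
        x[idx] = orig - eps
        minus = np.sum(out_grad * global_avg_pool_forward(x))
        x[idx] = orig
        grad[idx] = (plus - minus) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
out_grad = rng.standard_normal(2)

analytic = global_avg_pool_backward(x.shape, out_grad)
numeric = numeric_grad(x, out_grad)
print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

One difference to keep in mind: this op also averages over the channel axis (output shape (N,)), whereas a standard global average pooling layer averages each channel separately and keeps C, so a mean over channels would still be needed afterwards to reproduce it exactly.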