MXNet crashes when executing backward() for a custom symbol


#1

I am implementing a simple custom symbol, and MXNet crashes when backward() runs. Here is the code:

class GlobalAvgPool2D(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        y = mx.nd.mean(in_data[0], axis=0, exclude=True)
        self.assign(out_data[0], req[0], y)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        dims = in_grad[0].shape()
        m = 1
        for dim in dims[1:]: m *= dim
        self.assign(in_grad[0], req[0], in_grad[0]/float(m))

@mx.operator.register("global_avg_pool2d")
class GlobalAvgPool2DProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(GlobalAvgPool2DProp, self).__init__(True)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shapes):
        # return 3 lists representing inputs shapes, outputs shapes, and aux data shapes.
        return (in_shapes[0],), ([in_shapes[0][0]],), ()

    def create_operator(self, ctx, in_shapes, in_dtypes):
        # create and return the CustomOp class.
        return GlobalAvgPool2D()

Here is the error message:
terminate called after throwing an instance of 'dmlc::Error'
what(): [20:21:18] src/operator/custom/custom.cc:358: Check failed: reinterpret_cast(params.info->callbacks[kCustomOpBackward])( ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()), reinterpret_cast<const int*>(req.data()), static_cast(ctx.is_train), params.info->contexts[kCustomOpBackward])

Stack trace returned 7 entries:
[bt] (0) /home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x276938) [0x7f2098282938]
[bt] (1) /home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x276d48) [0x7f2098282d48]
[bt] (2) /home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x3a9c6a) [0x7f20983b5c6a]
[bt] (3) /home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x397246) [0x7f20983a3246]
[bt] (4) /home/ubuntu/anaconda3/envs/mxnet_p27/bin/…/lib/libstdc++.so.6(+0xafc5c) [0x7f2180ed0c5c]
[bt] (5) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f2181f126ba]
[bt] (6) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f218153841d]


#2

I am not sure where the problem is, but have you seen this class? https://mxnet.incubator.apache.org/api/python/symbol/contrib.html?highlight=adaptiveavgpooling2d#mxnet.symbol.contrib.AdaptiveAvgPooling2D

I think it does something similar, so you might not need a custom operator at all.
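
If you would rather keep the custom operator, two things in the posted backward() may be worth checking. First, shape is an attribute on an MXNet NDArray (as it is on a NumPy array), so in_grad[0].shape() tries to call a tuple and raises a TypeError; an exception raised inside the Python callback is a plausible cause of the "Check failed" abort in custom.cc. Second, the gradient is computed from in_grad[0], which is the output buffer, rather than from out_grad[0]. Below is a minimal sketch of the intended math, written in plain NumPy so that it runs without MXNet. The function names are hypothetical and not part of any MXNet API; it assumes the same convention as the posted forward(), averaging over every axis except the batch axis.

```python
import numpy as np


def global_avg_pool_forward(x):
    """Average over every axis except the batch axis 0; returns shape (N,)."""
    return x.reshape(x.shape[0], -1).mean(axis=1)


def global_avg_pool_backward(out_grad, in_shape):
    """Spread each sample's output gradient evenly over its m inputs."""
    # in_shape is a plain tuple: read it via `.shape`, never `.shape()`.
    m = 1
    for dim in in_shape[1:]:
        m *= dim
    # Reshape (N,) to (N, 1, 1, ...) so it broadcasts against in_shape.
    expanded = out_grad.reshape((in_shape[0],) + (1,) * (len(in_shape) - 1))
    return np.broadcast_to(expanded, in_shape) / float(m)


# Small example: batch of 2, 3 channels, 2x2 spatial size.
x = np.arange(2 * 3 * 2 * 2, dtype=np.float64).reshape(2, 3, 2, 2)
y = global_avg_pool_forward(x)
out_grad = np.array([1.0, 2.0])
in_grad = global_avg_pool_backward(out_grad, x.shape)
```

In the CustomOp itself, the same idea would mean reading dims from in_data[0].shape and building the value passed to self.assign from out_grad[0] (reshaped and broadcast to the input shape) divided by m. Treat that as a suggestion to try rather than a confirmed fix.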