An example(predicting with pretrained nets.) does not run on gpu

I’m running an example. (predict with pretrained nets.). This runs well on cpu, but when I change the context from mx.cpu() to mx.gpu(0), this produces errors.

errors


---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
 in 
     42         print('probability=%f, class=%s' %(prob[i], labels[i]))
     43 
---> 44 predict('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true')

 in predict(url)
     35     # compute the predict probabilities
     36     mod.forward(Batch([img]))
---> 37     prob = mod.get_outputs()[0].asnumpy()
     38     # print the top-5
     39     prob = np.squeeze(prob)

/home/me/anaconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py in asnumpy(self)
   1970             self.handle,
   1971             data.ctypes.data_as(ctypes.c_void_p),
-> 1972             ctypes.c_size_t(data.size)))
   1973         return data
   1974 

/home/me/anaconda3/lib/python3.7/site-packages/mxnet/base.py in check_call(ret)
    250     """
    251     if ret != 0:
--> 252         raise MXNetError(py_str(_LIB.MXGetLastError()))
    253 
    254 

MXNetError: [01:13:26] src/ndarray/ndarray_function.cu:45: Check failed: to->type_flag_ == from.type_flag_ (0 vs. 3) Source and target must have the same data type when copying across devices.

Stack trace returned 10 entries:
[bt] (0) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x382d4a) [0x7f376dd5ad4a]
[bt] (1) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x383381) [0x7f376dd5b381]
[bt] (2) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x4df96e8) [0x7f37727d16e8]
[bt] (3) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2ce0546) [0x7f37706b8546]
[bt] (4) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2cf635a) [0x7f37706ce35a]
[bt] (5) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2cf648b) [0x7f37706ce48b]
[bt] (6) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2af2a24) [0x7f37704caa24]
[bt] (7) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2af9aa3) [0x7f37704d1aa3]
[bt] (8) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2af9cf6) [0x7f37704d1cf6]
[bt] (9) /home/me/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2af3134) [0x7f37704cb134]

code

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
# define a simple data batch
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

ctx = mx.gpu(0) # when I change this to mx.cpu(), this runs well.

sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-18', 0) #successful.
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))], 
         label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)
with open('synset.txt', 'r') as f:
    labels = [l.rstrip() for l in f]
    
def get_image(url, show=False):
    # download and show the image. Remove query string from the file name.
    fname = mx.test_utils.download(url, fname=url.split('/')[-1].split('?')[0])
    img = mx.image.imread(fname)
    if img is None:
        return None
    if show:
        plt.imshow(img.asnumpy())
        plt.axis('off')
    # convert into format (batch, RGB, width, height)
    img = mx.image.imresize(img, 224, 224) # resize
    img = img.transpose((2, 0, 1)) # Channel first
    img = img.expand_dims(axis=0) # batchify
    return img

def predict(url):
    img = get_image(url, show=True)
    # compute the predict probabilities
    mod.forward(Batch([img]))
    prob = mod.get_outputs()[0].asnumpy()
    # print the top-5
    prob = np.squeeze(prob)
    a = np.argsort(prob)[::-1]
    for i in a[0:5]:
        print('probability=%f, class=%s' %(prob[i], labels[i]))

predict('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true')

Solved.

In get_image function,

change return img to return img.astype('float32')

@Pilhoon_Jang, glad you found a solution to your problem. This is using the old module API, I would suggest looking rather at this tutorial, using the more modern and better Gluon API.