How to use Gluon's DataLoader with MXNet's Symbol API?


#1

For example, with the `fit` function of `Module`.

Thanks


#2

Hi @kli-nlpr,

I have used Symbol DataIters in Gluon code, but not the other way around. It should still be possible, though, by wrapping the Gluon DataLoader and adapting its interface. What's your use case for sticking with the Symbol API rather than moving to Gluon?

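Something like the following could work. It yields `mx.io.DataBatch` objects, but because it is a plain generator it has no `reset()` or `provide_data`/`provide_label`, which `Module.fit` needs: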

def DataLoaderIter(data_loader, data_name='data', label_name='label'):
    """Yield each ``(data, label)`` pair of a Gluon DataLoader as an ``mx.io.DataBatch``.

    Parameters
    ----------
    data_loader : iterable of (data, label)
        Typically an ``mx.gluon.data.DataLoader`` over an ``ArrayDataset``.
    data_name : str, optional
        Name stored in each batch's ``provide_data`` descriptor (default ``'data'``).
    label_name : str, optional
        Name stored in each batch's ``provide_label`` descriptor (default
        ``'label'``). A Module matches this against its ``label_names``, so pass
        e.g. ``'softmax_label'`` for a net ending in ``SoftmaxOutput(name='softmax')``.

    Yields
    ------
    mx.io.DataBatch
        One batch per loader item, with a single data and a single label array.

    Note: this is a plain generator. It has no ``reset()`` and no
    ``provide_data``/``provide_label`` attributes, so it cannot be handed to
    ``Module.fit`` directly; wrap it in a class exposing those for that use.
    """
    for data, label in data_loader:
        # Describe each batch from its own arrays so a smaller final batch
        # (DataLoader's default last_batch='keep') reports its real shape.
        data_desc = mx.io.DataDesc(name=data_name, shape=data.shape, dtype=data.dtype)
        label_desc = mx.io.DataDesc(name=label_name, shape=label.shape, dtype=label.dtype)
        yield mx.io.DataBatch(data=[data], label=[label],
                              provide_data=[data_desc],
                              provide_label=[label_desc])
# Demo of DataLoaderIter. X and y are stand-in data (10 samples, 3 features,
# 1-column label) so the snippet runs on its own; replace them with your own
# arrays. With batch_size=5 every batch has the shapes asserted below.
import mxnet as mx

X = mx.nd.ones((10, 3))
y = mx.nd.zeros((10, 1))
dataset = mx.gluon.data.dataset.ArrayDataset(X, y)
data_loader = mx.gluon.data.DataLoader(dataset, batch_size=5)
for batch in DataLoaderIter(data_loader):
    assert isinstance(batch, mx.io.DataBatch)
    assert batch.data[0].shape == (5, 3)
    assert batch.label[0].shape == (5, 1)

#3

And for completeness, to go the other way (from a Symbol DataIter to a Gluon-style loader that yields `(data, label)` pairs), you can try:

def DataIterLoader(data_iter, discard_pad=False):
    """Yield ``(data, label)`` pairs from a Symbol-API ``DataIter``, Gluon style.

    ``data_iter.reset()`` is called once, when iteration of the returned
    generator begins, so each new generator starts from the first batch.

    Parameters
    ----------
    data_iter : mx.io.DataIter
        Iterator whose batches carry exactly one data array and one label array.
    discard_pad : bool, optional
        If True, drop the ``batch.pad`` trailing samples that some iterators
        add to fill up the last batch. Defaults to False, which yields each
        batch exactly as the iterator produced it.

    Yields
    ------
    tuple
        ``(data, label)``: the single data and label arrays of each batch.

    Raises
    ------
    ValueError
        If a batch does not contain exactly one data and one label array.
    """
    data_iter.reset()
    for batch in data_iter:
        # Explicit checks rather than ``assert``, which ``python -O`` strips.
        datas = batch.data if batch.data is not None else []
        labels = batch.label if batch.label is not None else []
        if len(datas) != 1:
            raise ValueError('expected exactly 1 data array per batch, got %d' % len(datas))
        if len(labels) != 1:
            raise ValueError('expected exactly 1 label array per batch, got %d' % len(labels))
        data, label = datas[0], labels[0]
        pad = getattr(batch, 'pad', None) or 0  # DataBatch.pad may be None
        if discard_pad and pad > 0:
            keep = len(data) - pad
            data, label = data[:keep], label[:keep]
        yield data, label
# Demo of DataIterLoader. data_iter is a stand-in NDArrayIter (10 samples,
# 3 features, 1-column label, batch_size=5) so the snippet runs on its own;
# replace it with your own mx.io.DataIter.
import mxnet as mx

data_iter = mx.io.NDArrayIter(mx.nd.ones((10, 3)), mx.nd.zeros((10, 1)), batch_size=5)
data_iter_loader = DataIterLoader(data_iter)
for X_batch, y_batch in data_iter_loader:
    assert X_batch.shape == (5, 3)
    assert y_batch.shape == (5, 1)

#4

Thank you very much for your reply. I use MXNet's Symbol interface because I want to use mxnet-memonger, which is not yet supported by the Gluon interface. Thanks!


#5

Here is a full working example:

import mxnet as mx
import logging
# Lower the root logger's threshold so INFO-level training logs are printed.
logging.root.setLevel(logging.INFO)

# Toy dataset: 100 normally distributed samples of 10 features; every label is 1.
X = mx.nd.random_normal(shape=(100, 10), ctx=mx.cpu())
Y = mx.nd.ones(shape=(100,), ctx=mx.cpu())
dataset = mx.gluon.data.dataset.ArrayDataset(X, Y)
gluon_data_loader = mx.gluon.data.DataLoader(dataset=dataset, batch_size=10)

class SimpleIter(object):
    """Adapt a Gluon ``DataLoader`` to the iterator interface used by ``Module.fit``.

    ``Module.fit`` reads ``provide_data``/``provide_label`` to bind the module,
    calls ``reset()`` between epochs, and pulls ``mx.io.DataBatch`` objects
    with ``next()``. This class provides those on top of a loader that yields
    ``(data, label)`` pairs.

    Parameters
    ----------
    gluon_data_loader : mx.gluon.data.DataLoader
        Loader yielding ``(data, label)`` NDArray pairs.
    data_name : str, optional
        Name of the data descriptor; must match the module's ``data_names``
        (default ``'data'``).
    label_name : str, optional
        Name of the label descriptor; must match the module's ``label_names``
        (default ``'softmax_label'``).

    Raises
    ------
    ValueError
        If ``gluon_data_loader`` yields no batches at all.
    """

    def __init__(self, gluon_data_loader, data_name='data', label_name='softmax_label'):
        self.gluon_data_loader = gluon_data_loader
        self._data_name = data_name
        self._label_name = label_name

        # Peek at one batch so the module can be bound with concrete shapes,
        # then start a fresh pass so that batch is not lost.
        try:
            data, label = next(iter(self.gluon_data_loader))
        except StopIteration:
            raise ValueError('gluon_data_loader yielded no batches')
        self._provide_data, self._provide_label = self._describe(data, label)

        self.gluon_data_loader_iter = iter(self.gluon_data_loader)

    def _describe(self, data, label):
        """Return ``([data_desc], [label_desc])`` describing one batch."""
        data_desc = mx.io.DataDesc(name=self._data_name, shape=data.shape, dtype=data.dtype)
        label_desc = mx.io.DataDesc(name=self._label_name, shape=label.shape, dtype=label.dtype)
        return [data_desc], [label_desc]

    def __iter__(self):
        return self

    def reset(self):
        """Restart iteration from the loader's first batch."""
        self.gluon_data_loader_iter = iter(self.gluon_data_loader)

    def __next__(self):
        # Python 3 iterator protocol; delegates to the Python 2 style ``next``.
        return self.next()

    @property
    def provide_data(self):
        """Data descriptors of the first batch, used to bind the module."""
        return self._provide_data

    @property
    def provide_label(self):
        """Label descriptors of the first batch, used to bind the module."""
        return self._provide_label

    def next(self):
        """Return the next batch as an ``mx.io.DataBatch``.

        Raises StopIteration once the loader is exhausted.
        """
        data, label = next(self.gluon_data_loader_iter)
        # Describe *this* batch instead of reusing the first batch's
        # descriptors: with DataLoader's default last_batch='keep' the final
        # batch can be smaller, and Module.forward (MXNet 1.x) reshapes to the
        # batch's provide_data when the incoming shapes change.
        provide_data, provide_label = self._describe(data, label)
        return mx.io.DataBatch(data=[data], label=[label],
                               provide_data=provide_data,
                               provide_label=provide_label)


# Two-layer MLP: 64 ReLU hidden units feeding a 2-way softmax output.
# The output named 'softmax' is what the 'softmax_label' label name refers to.
net = mx.sym.SoftmaxOutput(
    mx.sym.FullyConnected(
        mx.sym.Activation(
            mx.sym.FullyConnected(mx.sym.Variable('data'), name='fc1', num_hidden=64),
            name='relu1', act_type='relu'),
        name='fc2', num_hidden=2),
    name='softmax')

mod = mx.mod.Module(net,
                    data_names=['data'],
                    label_names=['softmax_label'],
                    context=mx.cpu())

train_iter = SimpleIter(gluon_data_loader)

# Train for 8 epochs with plain SGD. The training iterator doubles as the
# evaluation set, so the reported validation accuracy is on training data.
mod.fit(train_iter,
        num_epoch=8,
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.1},
        eval_metric='acc',
        eval_data=train_iter)

#6

Much more complete than my answer! Good work!