Memory profiling for MxNet

Hello,
I am using the bucketing module in MxNet to construct tree-structured neural networks. I have two different tree classes. One of them uses about 4 GB of memory during training and testing. The other, which is a minor variation of the first, uses hundreds of gigabytes. I used Python's memory_profiler, and it shows that model.fit() is where the excess memory usage happens. Does anyone know how I can profile memory usage inside MxNet's own functions? If I am not mistaken, memory_profiler does not work for nested functions. Any help is very much appreciated!
Thank you,
Forough.

Please share the code, since it’s hard to tell what a minor variation is 🙂

Thank you for your reply.
I was able to find where the leak was happening, and it was not due to my tree classes. I pass an object as the bucket_key to the sym_gen function of the bucketing module, and I realized that I was creating a new instance of that object every time my iterator's next() was called.
For future reference, is there a specific profiler that you use for MxNet?
Thank you,
Forough.
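For the Python side of a leak like this, the standard library's tracemalloc module can compare two heap snapshots and attribute the growth to the source line that allocated it, however deep that line sits in the call stack. It does not see buffers that MXNet allocates in its C++ backend (NDArray storage, executor memory), so here it would only expose the accumulating key objects, not the bulk of the hundreds of gigabytes. The sketch below is a minimal illustration; `bucketIndex`, `kept_alive`, `leaky_next` and `top_growth` are hypothetical names, not code from this thread.

```python
import tracemalloc


class bucketIndex(object):
    """Hypothetical stand-in for the key class discussed in this thread."""

    def __init__(self, index, devFlag=0):
        self.index = index
        self.devFlag = devFlag


kept_alive = []  # plays the role of whatever holds on to the keys (e.g. a cache dict)


def leaky_next(i):
    # Buggy pattern: a new key object on every call, never released.
    kept_alive.append(bucketIndex(i))


def top_growth(func, calls, limit=3):
    """Call func(n) `calls` times; return the largest allocation changes per source line."""
    tracemalloc.start()
    before = tracemalloc.take_snapshot()
    for n in range(calls):
        func(n)
    after = tracemalloc.take_snapshot()
    tracemalloc.stop()
    # Ignore allocations made inside the tracemalloc module itself.
    ignore = [tracemalloc.Filter(False, tracemalloc.__file__)]
    stats = after.filter_traces(ignore).compare_to(before.filter_traces(ignore), "lineno")
    return stats[:limit]


for stat in top_growth(leaky_next, 10000):
    print(stat)  # one line per source location, with size and count deltas
```

The top entries point at the lines inside `leaky_next` and `bucketIndex.__init__`, which is the kind of hint that would have led to the next() method here.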

@piiswrong, any suggestions? Debugging on GPUs is tricky.
@forough - do you mind sharing an example of the good and the ‘leaky’ code for the benefit of other readers?

Sure. Here is an example of the code.

I am using the bucketing module to construct tree-structured neural networks. The problem was in the next method of my data iterator:

Buggy code:

class BucketEqIterator(mx.io.DataIter):
    def __init__(self, ...):  # constructor arguments omitted in this excerpt
        super(BucketEqIterator, self).__init__()
        # initialize some attributes
        # define the default bucket key, provide_data and provide_label
        # construct the bucket-key hash functions
        # form self.data accordingly

    def reset(self):
        self.curr_idx = 0
        # shuffle the data if you wish

        self.nddata = []
        self.ndlabel = []
        for i, buck in enumerate(self.data):
            label = self.labels[i]
            self.nddata.append([mx.ndarray.array(buck[k], dtype=self.dtype) for k in range(len(buck))])
            self.ndlabel.append(mx.ndarray.array(label, dtype=self.dtype))

    def next(self):
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        # indexing depends on your code design
        if self.major_axis == 1:
            data = self.nddata[i].T
            label = self.ndlabel[i].T
        else:
            data = self.nddata[i]
            label = self.ndlabel[i]

        d = mx.io.DataBatch(data, [label], pad=0,
                            # BUG: a brand-new bucketIndex object is created on every call
                            bucket_key=bucketIndex(self.buckets[i], self.devFlag),
                            provide_data=[(self.data_name[i][self.default_bucket_key][j], (self.batch_size, self.vocabSize))
                                          if self.dataFlag[i][j] == 0 else
                                          (self.data_name[i][self.default_bucket_key][j], (self.batch_size, ))
                                          for j in range(len(self.data_name[i][self.default_bucket_key]))],
                            provide_label=[(self.label_name, label.shape)])
        return d

class bucketIndex(object):
    def __init__(self, index, devFlag=0):
        # initialize some attributes (details omitted)
        pass

    # define some desired methods

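The problem came from the new bucketIndex object created on every call to next() (pretty trivial, but annoyingly hidden!). Presumably, because each batch carried a key that had never been seen before, the bucketing module treated every batch as a new bucket, calling sym_gen and binding a new executor each time, so memory kept growing.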
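The mechanism can be reproduced without MXNet. As far as I can tell from the BucketingModule source, switch_bucket binds a new module only when the batch's bucket_key is not already in its internal dictionary, and that dictionary keeps every module alive. With Python's default identity-based `__eq__`/`__hash__`, a freshly created bucketIndex never matches an earlier one. In the toy model below, `ToyBucketCache` and `count_cached_buckets` are hypothetical names, and a placeholder `object()` stands in for a bound module and its memory.

```python
class bucketIndex(object):
    """Key class with Python's default identity-based __eq__ and __hash__."""

    def __init__(self, index, devFlag=0):
        self.index = index
        self.devFlag = devFlag


class ToyBucketCache(object):
    """Hypothetical, simplified stand-in for a per-bucket-key module cache."""

    def __init__(self):
        self.buckets = {}       # bucket_key -> placeholder for a bound module
        self.sym_gen_calls = 0  # how many times a new network would be built

    def switch_bucket(self, bucket_key):
        # Build only for keys not seen before; every entry stays alive in the dict.
        if bucket_key not in self.buckets:
            self.sym_gen_calls += 1
            self.buckets[bucket_key] = object()
        return self.buckets[bucket_key]


def count_cached_buckets(num_batches, reuse_keys, num_buckets=2):
    cache = ToyBucketCache()
    keys = [bucketIndex(b) for b in range(num_buckets)]  # created once, as in the fix
    for step in range(num_batches):
        b = step % num_buckets
        # Buggy path: a fresh object per batch. Fixed path: one shared object per bucket.
        key = keys[b] if reuse_keys else bucketIndex(b)
        cache.switch_bucket(key)
    return len(cache.buckets)


print(count_cached_buckets(1000, reuse_keys=False))  # 1000: one entry per batch
print(count_cached_buckets(1000, reuse_keys=True))   # 2: one entry per bucket
```

With real executors in place of the placeholder, the first case is the unbounded growth described in this thread.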

The fix is simply to create the key objects once, with self.buckets[i] = bucketIndex(…) in the __init__ method of the original iterator class, and then reuse them in next():

d = mx.io.DataBatch(data, [label], pad=0,
                    bucket_key=self.buckets[i],  # reuse the instance created in __init__
                    provide_data=[(self.data_name[i][0][j], (self.batch_size, self.vocabSize))
                                  if self.dataFlag[i][j] == 0 else
                                  (self.data_name[i][0][j], (self.batch_size, ))
                                  for j in range(len(self.data_name[i][0]))],
                    provide_label=[(self.label_name, label.shape)])
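Another option, not the fix used in this thread, is to give the key class value-based equality, so that separately created instances describing the same bucket collapse into one dictionary key. This is a sketch under assumptions: it assumes a bucketIndex is fully described by `index` and `devFlag` and that both are hashable (ints or tuples, for example); the real class has attributes and methods the post omits.

```python
class bucketIndex(object):
    """Bucket key whose equality and hash depend on its contents, not its identity."""

    def __init__(self, index, devFlag=0):
        self.index = index      # assumed hashable (e.g. int or tuple)
        self.devFlag = devFlag

    def _fields(self):
        return (self.index, self.devFlag)

    def __eq__(self, other):
        if not isinstance(other, bucketIndex):
            return NotImplemented
        return self._fields() == other._fields()

    def __ne__(self, other):
        # Python 3 derives != from __eq__; Python 2 does not, so define it explicitly.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        return hash(self._fields())


cache = {}
for step in range(1000):
    key = bucketIndex(step % 2)      # a new object every time...
    cache.setdefault(key, object())  # ...but equal keys share a single entry

print(len(cache))  # 2
```

In that setup, sym_gen would only ever see the first instance of each distinct key. Reusing pre-built instances, as in the correction above, avoids the need for custom hashing altogether.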

(It would be nice if the interface supported tabs; it was pretty painful to add the code 🙂)