Memory leak with custom operation

We seem to have a memory leak when using a custom operation with the Gluon API.
Here is test code that reproduces it.

=======

import gc
import tracemalloc
import numpy as np
from mxnet import nd, autograd, cpu
from mxnet.gluon import Block, Trainer, nn
from mxnet.operator import register, CustomOp, CustomOpProp

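# Toy custom op: forward copies the input to the output, and backward copies the input into the input gradient.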
class CustomLoss(CustomOp):
	def __init__(self):
		super(CustomLoss, self).__init__()
	def forward(self, is_train, req, in_data, out_data, aux):
		z = in_data[0].asnumpy()
		self.assign(out_data[0], req[0], nd.array(z))
	def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
		z = in_data[0].asnumpy()
		self.assign(in_grad[0], req[0], nd.array(z))

@register("customloss")
class CustomLossProp(CustomOpProp):
	def __init__(self):
		super(CustomLossProp, self).__init__(True)
	def infer_shape(self, in_shape):
		input_shape = in_shape[0]
		batch_size = input_shape[0]
		out_shape = (batch_size, 100)
		return (input_shape, ), (out_shape,), ()
	def list_arguments(self):
		return ['data']
	def list_outputs(self):
		return ['output']
	def create_operator(self, ctx, in_shapes, in_dtypes):
		return CustomLoss()

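# Gluon Block that wraps the custom op so it can be used inside a model.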
class CustomBlock(Block):
	def __init__(self, **kwargs):
		super(CustomBlock, self).__init__(**kwargs)
	def forward(self, x):
		ctx = x.context
		return nd.Custom(x, op_type='customloss')

class Model(Block):
	def __init__(self, **kwargs):
		super(Model, self).__init__(**kwargs)
		with self.name_scope():
			self.dense = nn.Dense(100)
			self.custom = CustomBlock()
	def forward(self, x):
		return self.custom(self.dense(x))

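# Reproduction: train on random data and compare tracemalloc snapshots every iteration.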
model = Model()
model.initialize(ctx=cpu(0))
trainer = Trainer(model.collect_params(), 'adam')
data = nd.array(np.random.randn(100000, 100))
gc.collect()
tracemalloc.start()
snapshot1 = snapshot2 = tracemalloc.take_snapshot()
for epoch in range(10000):
	with autograd.record():
		loss = model(data)
	loss.backward()
	trainer.step(data.shape[0], ignore_stale_grad=True)
	del loss
	gc.collect()
	snapshot2 = tracemalloc.take_snapshot()
	top_stats = snapshot2.compare_to(snapshot1, 'lineno')
	print("epoch %d [ Memory leak Test ]" % epoch)
	for stat in top_stats[:10]:
		if str(stat).startswith('/usr/local/lib/python3.5/dist-packages/mxnet/base.py'):
			print(stat)
			break

=======

Is there any mistake in the code?
When the code is executed, the memory usage attributed to mxnet/base.py keeps growing without bound.

epoch 0 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:254: size=11.9 KiB (+11.9 KiB), count=80 (+80), average=152 B
epoch 1 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:254: size=13.0 KiB (+13.0 KiB), count=91 (+91), average=147 B
epoch 2 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:254: size=14.2 KiB (+14.2 KiB), count=102 (+102), average=142 B
epoch 3 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:254: size=15.3 KiB (+15.3 KiB), count=113 (+113), average=139 B
epoch 4 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:285: size=16.5 KiB (+16.5 KiB), count=124 (+124), average=136 B
epoch 5 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:285: size=18.9 KiB (+18.9 KiB), count=142 (+142), average=136 B
epoch 6 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:285: size=21.3 KiB (+21.3 KiB), count=161 (+161), average=135 B
epoch 7 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:285: size=23.7 KiB (+23.7 KiB), count=179 (+179), average=135 B
epoch 8 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:285: size=26.0 KiB (+26.0 KiB), count=197 (+197), average=135 B
epoch 9 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:285: size=28.4 KiB (+28.4 KiB), count=215 (+215), average=135 B
epoch 10 [ Memory leak Test ]
/usr/local/lib/python3.5/dist-packages/mxnet/base.py:285: size=30.8 KiB (+30.8 KiB), count=234 (+234), average=135 B

Thank you for reading.
If this turns out to be an MXNet issue, I hope it can be fixed.
If the problem is in my code, please point out the offending part.

Add a synchronous call,

a = loss.wait_to_read()

in your loop and you’ll see the “memory leak” go away.
What I think is happening:

  • MXNet operations are executed asynchronously: they are enqueued in the backend engine and the Python calls return immediately. So when you pass data through your network, the operators are enqueued on the MXNet engine, and if you enqueue operators faster than they are processed, you will see memory usage grow. For example, even without a synchronous call, if you reduce your data size so that each iteration is processed faster than your tracemalloc snapshot is taken, you will no longer see the memory increasing. A minimal sketch of the loop with the synchronous call added is shown below.
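For reference, a minimal sketch of the training loop from the test code above with the synchronous call added (everything else stays the same):

for epoch in range(10000):
	with autograd.record():
		loss = model(data)
	loss.backward()
	trainer.step(data.shape[0], ignore_stale_grad=True)
	# Block here until the loss has actually been computed, so the engine's
	# queue of pending operations cannot grow without bound across iterations.
	loss.wait_to_read()
	del loss
	gc.collect()
	# ... tracemalloc snapshot and printing code unchanged ...

Note that wait_to_read() returns None, so the a = assignment in the one-liner above is not strictly needed; calling nd.waitall() at the end of each iteration would also force all pending operations to complete.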

Thank you Thomas! I had been struggling with this for a long time.