Gluon Multi-GPU Out of Memory Issues

I’ve been following the guides for working with the Gluon API on multiple GPUs, but I’m running into memory errors when attempting to count the number of correct predictions during my validation loop:

mxnet.base.MXNetError: [15:24:56] src/storage/./pooled_storage_manager.h:143: cudaMalloc failed: out of memory

I still run out of memory even after reducing the batch size to 2. The training loop seems to work fine, so I must be doing something wrong, but I haven’t been able to figure out what.

Even if I wait for the computations before printing a single prediction, I hit the error, so presumably the real problem is happening in the mx.nd.argmax(net(X), axis=1) for X in data part of the code. I had hoped turning it into a generator would help, but it doesn’t seem to.

def valid_batch(data_it, label_it, ctx, net):
    data = gluon.utils.split_and_load(data_it, ctx)
    preds = (mx.nd.argmax(net(X), axis=1) for X in data)

    for pred in preds:
        print("datum", type(pred), pred.shape, pred.size)
        pred.wait_to_read()  # Error occurs here
        print(pred)

The shape of each NDArray in data is (4, 3, 480, 640).

If it helps, here’s all of the code together:

class DataIterLoader():
    def __init__(self, data_iter):
        self.data_iter = data_iter

    def __iter__(self):
        self.data_iter.reset()
        return self

    def __next__(self):
        batch = self.data_iter.__next__()
        assert len(batch.data) == len(batch.label) == 1
        data = batch.data[0]
        label = batch.label[0]
        return data, label

    def next(self):
        return self.__next__() # for Python 2


def forward_backward(net, data, label):
    with mx.autograd.record():
        losses = [loss_fn(net(X), Y) for X, Y in zip(data, label)]
    for l in losses:
        l.backward()


def train_batch(data_it, label_it, ctx, net, trainer):
    # Split the data batch and load them on GPUs
    data = gluon.utils.split_and_load(data_it, ctx)
    label = gluon.utils.split_and_load(label_it, ctx)
    # Compute gradient
    forward_backward(net, data, label)
    # Update parameters
    trainer.step(data_it.shape[0])


def valid_batch(data_it, label_it, ctx, net):
    data = gluon.utils.split_and_load(data_it, ctx)
    #labels = gluon.utils.split_and_load(label_it, ctx)
    preds = (mx.nd.argmax(net(X), axis=1) for X in data)

    for pred in preds:
        print("datum", type(pred), pred.shape, pred.size)
        pred.wait_to_read()  # out of memory error here
        print(pred)  
        break


# Load the RGB means for the training set, then determine the batch
# size
means = json.loads(open(args["means"]).read())
bat_size = config.BATCH_SIZE * args["num_devices"]

train_iter = mx.io.ImageRecordIter(
    path_imgrec=config.TRAIN_MX_REC,
    data_shape=(3, 480, 640),
    batch_size=bat_size,
    #rand_crop=True,
    rand_mirror=True,
    #rotate=15,
    #max_shear_ratio=0.1,
    mean_r=means["R"],
    mean_g=means["G"],
    mean_b=means["B"],
    preprocess_threads=args["num_devices"] * 2
)

val_iter = mx.io.ImageRecordIter(
    path_imgrec=config.VAL_MX_REC,
    data_shape=(3, 480, 640),
    batch_size=bat_size,
    mean_r=means["R"],
    mean_g=means["G"],
    mean_b=means["B"]
)

train_iter_loader = DataIterLoader(train_iter)
val_iter_loader = DataIterLoader(val_iter)

# Construct the checkpoints path
checkpoints_path = os.path.sep.join([args["checkpoints"],
                                     args["prefix"]])

# If there is no specific model starting epoch supplied, then
# initialize the network
if args["start_epoch"] <= 0:
    # Build the VGGNet architecture
    print("[INFO] Building network...")
    model = VGG19()

# Otherwise, a specific checkpoint was supplied
else:
    # Load the checkpoint from disk
    print("[INFO] Loading epoch {}...".format(args["start_epoch"]))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Figure out checkpoint filename
        pad = 4 - len(str(args["start_epoch"]))
        zeroes = "0" * pad
        fname = args["prefix"] + "-" + zeroes + str(args["start_epoch"])
        # Load our model
        model = gluon.SymbolBlock.imports(args["prefix"] + "-symbol.json", ["data"], fname)

ctx = [mx.gpu(i) for i in range(0, args["num_devices"])]

model.initialize(mx.initializer.MSRAPrelu(), ctx=ctx)
model.hybridize()
trainer = gluon.Trainer(model.collect_params(), "sgd", {"learning_rate": args["learning_rate"]})

# Define our loss function
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

# Train the network
print("[INFO] Training network...")
for epoch in range(args["end_epoch"]):
    # Training Loop
    start = time()
    for d, l in train_iter_loader:  # start of mini-batch
        train_batch(d, l, ctx, model, trainer)
    mx.nd.waitall()  # Wait until all computations are finished to benchmark the time
    print("[Epoch {}] Training Time = {:.1f} sec".format(epoch, time() - start))

    # Validation loop
    correct, num = 0.0, 0.0
    for d, l in val_iter_loader:
        correct += valid_batch(d, l, ctx, model)
        num += d.shape[0]
        mx.nd.waitall()
    print("\tValidation Accuracy = {:.2f}".format(correct / num * 100))

    # Save a checkpoint
    path = os.path.sep.join([checkpoints_path, args["prefix"]])
    print("Saving checkpoint file {} to {}...".format(path, checkpoints_path))
    model.export(path, epoch=epoch)

And the full stack trace:

Traceback (most recent call last):
  File "train_vggnet.py", line 172, in <module>
    correct += valid_batch(d, l, ctx, model)
  File "train_vggnet.py", line 91, in valid_batch
    print(pred)
  File "/uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/ndarray/ndarray.py", line 189, in __repr__
    return '\n%s\n<%s %s @%s>' % (str(self.asnumpy()),
  File "/uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/ndarray/ndarray.py", line 1980, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/base.py", line 252, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [15:44:44] src/storage/./pooled_storage_manager.h:143: cudaMalloc failed: out of memory

Stack trace returned 10 entries:
[bt] (0) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x40123a) [0x2b9be9d3523a]
[bt] (1) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x401851) [0x2b9be9d35851]
[bt] (2) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x340f493) [0x2b9becd43493]
[bt] (3) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x341390e) [0x2b9becd4790e]
[bt] (4) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(void mxnet::CopyFromToDnsImpl<mshadow::gpu, mshadow::gpu>(mxnet::NDArray const&, mxnet::NDArray const&, mxnet::RunContext)+0x33a) [0x2b9bec7f85ca]
[bt] (5) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(void mxnet::CopyFromToImpl<mshadow::gpu, mshadow::gpu>(mxnet::NDArray const&, mxnet::NDArray const&, mxnet::RunContext, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&)+0x45d) [0x2b9bec81098d]
[bt] (6) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x2edcaab) [0x2b9bec810aab]
[bt] (7) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x2cb1494) [0x2b9bec5e5494]
[bt] (8) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x2cb8533) [0x2b9bec5ec533]
[bt] (9) /uufs/chpc.utah.edu/common/home/u6000791/venv/rana/lib/python3.5/site-packages/mxnet/libmxnet.so(+0x2cb8786) [0x2b9bec5ec786]

I think the problem comes from the fact that your training loop has no blocking call, so you are potentially flooding your GPU memory with your entire training set.
Try replacing your forward_backward with:

def forward_backward(net, data, label):
    with mx.autograd.record():
        losses = [loss_fn(net(X), Y) for X, Y in zip(data, label)]
    for l in losses:
        l.backward()
    for l in losses:
        l.wait_to_read()

Some errors in MXNet only bubble up when you try to read a value. That’s why you see the error in the validation loop: during the training loop you never copy any output back to the CPU, so everything stays on the GPU. In all likelihood, the cudaMalloc error actually originated in your training loop.
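
To see what I mean, here is a minimal sketch (not your model, just an illustration of the asynchronous engine):

import mxnet as mx

a = mx.nd.ones((1024, 1024), ctx=mx.gpu())
b = mx.nd.dot(a, a)  # returns immediately; the work is only enqueued on the GPU

# Any failure in the enqueued work (e.g. a cudaMalloc error) only surfaces
# at a blocking call such as one of these:
b.wait_to_read()     # or b.asnumpy(), or mx.nd.waitall()

Without such a call in the training loop, Python can keep enqueuing new batches faster than the GPU finishes them, which is what can blow up the memory.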


Hi, to add to what @ThomasDelteil said, what kind of model do you have? Can you please share the number of parameters? I think you are working on a semantic segmentation problem, and I suspect the image size you are using (480x640) is simply too large. In my tests (without the trick Thomas described, and with 256x256 inputs) I can fit only about 6-12 samples per GPU. A large input image means a large memory footprint during training, because the intermediate feature maps are large. Sometimes I have to go as low as 128x128. I haven’t experimented much with the wait_to_read command, to be honest.
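
For a rough sense of scale (my own back-of-the-envelope numbers, float32 activations only, ignoring gradients and cuDNN workspace), a single 64-channel feature map at your resolution is already quite large:

# 64 channels * H * W * 4 bytes, per sample
for h, w in [(480, 640), (256, 256)]:
    print("{}x{}: ~{:.0f} MB per sample".format(h, w, 64 * h * w * 4 / 1024 ** 2))
# 480x640: ~75 MB per sample, vs ~16 MB at 256x256

and the full footprint is that kind of number summed over every layer, times the per-GPU batch size.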

Hope this helps.

@ThomasDelteil you’re right that the memory issue is actually happening in the training loop. While the wait_to_read() calls didn’t solve my problem, they at least helped isolate it: I’m running out of memory in the very first iteration of the training loop.

@feevos The large image size is unfortunately necessary in my case, since the model is attempting to classify the presence/absence of small insects in low-resolution video. Because the insects already make up such a small fraction of each frame, I can’t afford to downsize the images for this use case.

That said, I had this working with the Module API, so maybe the issue can be solved by doing a better job of translating my Module API architecture to its Gluon equivalent?

Here is the model architecture that was capable of training with batch sizes of 4:

class MxVGGNet:
    @staticmethod
    def build(classes):
        # data input
        data = mx.sym.Variable("data")

        # Block 1: (CONV => RELU) * 2 => POOL
        conv1_1 = mx.sym.Convolution(data=data, kernel=(3, 3),
                                     pad=(1, 1), num_filter=64, name="conv1_1")
        bn1_1 = mx.sym.BatchNorm(data=conv1_1, name="bn1_1")
        act1_1 = mx.sym.LeakyReLU(data=bn1_1, act_type="prelu",
                                  name="act1_1")
        conv1_2 = mx.sym.Convolution(data=act1_1, kernel=(3, 3),
                                     pad=(1, 1), num_filter=64, name="conv1_2")
        bn1_2 = mx.sym.BatchNorm(data=conv1_2, name="bn1_2")
        act1_2 = mx.sym.LeakyReLU(data=bn1_2, act_type="prelu",
                                  name="act1_2")
        pool1 = mx.sym.Pooling(data=act1_2, pool_type="max",
                               kernel=(2, 2), stride=(2, 2), name="pool1")
        do1 = mx.sym.Dropout(data=pool1, p=0.25)

        # Block 2: (CONV => RELU) * 2 => POOL
        conv2_1 = mx.sym.Convolution(data=do1, kernel=(3, 3),
                                     pad=(1, 1), num_filter=128, name="conv2_1")
        bn2_1 = mx.sym.BatchNorm(data=conv2_1, name="bn2_1")
        act2_1 = mx.sym.LeakyReLU(data=bn2_1, act_type="prelu",
                                  name="act2_1")
        conv2_2 = mx.sym.Convolution(data=act2_1, kernel=(3, 3),
                                     pad=(1, 1), num_filter=128, name="conv2_2")
        bn2_2 = mx.sym.BatchNorm(data=conv2_2, name="bn2_2")
        act2_2 = mx.sym.LeakyReLU(data=bn2_2, act_type="prelu",
                                  name="act2_2")
        pool2 = mx.sym.Pooling(data=act2_2, pool_type="max",
                               kernel=(2, 2), stride=(2, 2), name="pool2")
        do2 = mx.sym.Dropout(data=pool2, p=0.25)

        # Block 3: (CONV => RELU) * 4 => POOL
        conv3_1 = mx.sym.Convolution(data=do2, kernel=(3, 3),
                                     pad=(1, 1), num_filter=256, name="conv3_1")
        bn3_1 = mx.sym.BatchNorm(data=conv3_1, name="bn3_1")
        act3_1 = mx.sym.LeakyReLU(data=bn3_1, act_type="prelu",
                                  name="act3_1")
        conv3_2 = mx.sym.Convolution(data=act3_1, kernel=(3, 3),
                                     pad=(1, 1), num_filter=256, name="conv3_2")
        bn3_2 = mx.sym.BatchNorm(data=conv3_2, name="bn3_2")
        act3_2 = mx.sym.LeakyReLU(data=bn3_2, act_type="prelu",
                                  name="act3_2")
        conv3_3 = mx.sym.Convolution(data=act3_2, kernel=(3, 3),
                                     pad=(1, 1), num_filter=256, name="conv3_3")
        bn3_3 = mx.sym.BatchNorm(data=conv3_3, name="bn3_3")
        act3_3 = mx.sym.LeakyReLU(data=bn3_3, act_type="prelu",
                                  name="act3_3")
        conv3_4 = mx.sym.Convolution(data=act3_3, kernel=(3, 3),
                                     pad=(1, 1), num_filter=256, name="conv3_4")
        bn3_4 = mx.sym.BatchNorm(data=conv3_4, name="bn3_4")
        act3_4 = mx.sym.LeakyReLU(data=bn3_4, act_type="prelu",
                                  name="act3_4")
        pool3 = mx.sym.Pooling(data=act3_4, pool_type="max",
                               kernel=(2, 2), stride=(2, 2), name="pool3")
        do3 = mx.sym.Dropout(data=pool3, p=0.25)

        # Block 4: (CONV => RELU) * 4 => POOL
        conv4_1 = mx.sym.Convolution(data=do3, kernel=(3, 3),
                                     pad=(1, 1), num_filter=512, name="conv4_1")
        bn4_1 = mx.sym.BatchNorm(data=conv4_1, name="bn4_1")
        act4_1 = mx.sym.LeakyReLU(data=bn4_1, act_type="prelu",
                                  name="act4_1")
        conv4_2 = mx.sym.Convolution(data=act4_1, kernel=(3, 3),
                                     pad=(1, 1), num_filter=512, name="conv4_2")
        bn4_2 = mx.sym.BatchNorm(data=conv4_2, name="bn4_2")
        act4_2 = mx.sym.LeakyReLU(data=bn4_2, act_type="prelu",
                                  name="act4_2")
        conv4_3 = mx.sym.Convolution(data=act4_2, kernel=(3, 3),
                                     pad=(1, 1), num_filter=512, name="conv4_3")
        bn4_3 = mx.sym.BatchNorm(data=conv4_3, name="bn4_3")
        act4_3 = mx.sym.LeakyReLU(data=bn4_3, act_type="prelu",
                                  name="act4_3")
        conv4_4 = mx.sym.Convolution(data=act4_3, kernel=(3, 3),
                                     pad=(1, 1), num_filter=512, name="conv4_4")
        bn4_4 = mx.sym.BatchNorm(data=conv4_4, name="bn4_4")
        act4_4 = mx.sym.LeakyReLU(data=bn4_4, act_type="prelu",
                                  name="act4_4")
        pool4 = mx.sym.Pooling(data=act4_4, pool_type="max",
                               kernel=(2, 2), stride=(2, 2), name="pool4")
        do4 = mx.sym.Dropout(data=pool4, p=0.25)

        # Block 5: (CONV => RELU) * 4 => POOL
        conv5_1 = mx.sym.Convolution(data=do4, kernel=(3, 3),
                                     pad=(1, 1), num_filter=512, name="conv5_1")
        bn5_1 = mx.sym.BatchNorm(data=conv5_1, name="bn5_1")
        act5_1 = mx.sym.LeakyReLU(data=bn5_1, act_type="prelu",
                                  name="act5_1")
        conv5_2 = mx.sym.Convolution(data=act5_1, kernel=(3, 3),
                                     pad=(1, 1), num_filter=512, name="conv5_2")
        bn5_2 = mx.sym.BatchNorm(data=conv5_2, name="bn5_2")
        act5_2 = mx.sym.LeakyReLU(data=bn5_2, act_type="prelu",
                                  name="act5_2")
        conv5_3 = mx.sym.Convolution(data=act5_2, kernel=(3, 3),
                                     pad=(1, 1), num_filter=512, name="conv5_3")
        bn5_3 = mx.sym.BatchNorm(data=conv5_3, name="bn5_3")
        act5_3 = mx.sym.LeakyReLU(data=bn5_3, act_type="prelu",
                                  name="act5_3")
        conv5_4 = mx.sym.Convolution(data=act5_3, kernel=(3, 3),
                                     pad=(1, 1), num_filter=512, name="conv5_4")
        bn5_4 = mx.sym.BatchNorm(data=conv5_4, name="bn5_4")
        act5_4 = mx.sym.LeakyReLU(data=bn5_4, act_type="prelu",
                                  name="act5_4")
        pool5 = mx.sym.Pooling(data=act5_4, pool_type="max",
                               kernel=(2, 2), stride=(2, 2), name="pool5")
        do5 = mx.sym.Dropout(data=pool5, p=0.25)

        # Block 6: FC => RELU layers
        flatten = mx.sym.Flatten(data=do5, name="flatten")
        fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=4096,
                                    name="fc1")
        bn6_1 = mx.sym.BatchNorm(data=fc1, name="bn6_1")
        act6_1 = mx.sym.LeakyReLU(data=bn6_1, act_type="prelu",
                                  name="act6_1")
        do6 = mx.sym.Dropout(data=act6_1, p=0.5)

        # Block 7: FC => RELU layers
        fc2 = mx.sym.FullyConnected(data=do6, num_hidden=4096,
                                    name="fc2")
        bn7_1 = mx.sym.BatchNorm(data=fc2, name="bn7_1")
        act7_1 = mx.sym.LeakyReLU(data=bn7_1, act_type="prelu",
                                  name="act7_1")
        do7 = mx.sym.Dropout(data=act7_1, p=0.5)

        # Softmax classifier
        fc3 = mx.sym.FullyConnected(data=do7, num_hidden=classes,
                                    name="fc3")
        model = mx.sym.SoftmaxOutput(data=fc3, name="softmax")

        return model

And here is the Gluon architecture that I’m now having memory trouble with:

class VGG19(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(VGG19, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')

            # Block 1
            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,
                                        weight_initializer=mx.init.Xavier(rnd_type='gaussian',
                                                                          factor_type='out',
                                                                          magnitude=2),
                                        bias_initializer='zeros'))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 2
            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 3
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 4
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 5
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 6
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",
                                       bias_initializer="zeros"))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            # Block 7
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",
                                       bias_initializer="zeros"))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            self.output = nn.Dense(2, weight_initializer="normal", bias_initializer="zeros")

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x

How much memory do you have on your GPU?

Can you try to run this:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class VGG19(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(VGG19, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')

            # Block 1
            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,
                                        weight_initializer=mx.init.Xavier(rnd_type='gaussian',
                                                                          factor_type='out',
                                                                          magnitude=2),
                                        bias_initializer='zeros'))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 2
            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 3
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 4
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 5
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 6
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",
                                       bias_initializer="zeros"))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            # Block 7
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",
                                       bias_initializer="zeros"))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            self.output = nn.Dense(2, weight_initializer="normal", bias_initializer="zeros")

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x

net = VGG19()
ctx=mx.gpu()
net.initialize(ctx=ctx)
net.hybridize(static_shape=True, static_alloc=True)
for i in range(10):
    with mx.autograd.record():
        out = net(mx.nd.ones((4, 3, 480, 640), ctx=ctx))
    out.backward()
    print(out.asnumpy())

For me, memory usage peaks at ~15 GB and then stays around 11 GB.
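
If you want to compare on your side, you can watch nvidia-smi while the loop runs, or query it from Python (assuming your MXNet version is recent enough to have gpu_memory_info):

free, total = mx.context.gpu_memory_info(0)  # bytes, for GPU 0
print("GPU 0: {:.1f} / {:.1f} GB free".format(free / 1024 ** 3, total / 1024 ** 3))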

Alternatively, can you try batch sizes of 1, 2, and 3 and see if you hit the same issue?

You can also try a different architecture. Have a look at the GluonCV model comparison, https://gluon-cv.mxnet.io/model_zoo/classification.html; VGG is actually worse than almost everything else in terms of speed, memory consumption, and accuracy. Try a ResNet-50.

Here is a tutorial on fine-tuning in Gluon: https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/gluon_from_experiment_to_deployment.html

Here is another one from Gluon-cv https://gluon-cv.mxnet.io/build/examples_classification/transfer_learning_minc.html


@auslaner,

It turns out that in my case, going multi-GPU produces cudaMalloc errors as well. The reason is that the KVStore lives on GPU 0 by default, and there isn’t enough space on that device to hold both the model and the gradient updates. Switching to kvstore=‘local’ works, but it is fairly slow since the updates are then done on the CPU.
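
For reference, that choice is made when you build the Trainer; roughly (a sketch, adapt the optimizer and options to your setup):

# updates aggregated on CPU: saves GPU memory, but slower
trainer = gluon.Trainer(model.collect_params(), "sgd",
                        {"learning_rate": args["learning_rate"]},
                        kvstore="local")

# updates aggregated on the devices: faster, needs extra GPU memory
trainer = gluon.Trainer(model.collect_params(), "sgd",
                        {"learning_rate": args["learning_rate"]},
                        kvstore="device")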

I’d suggest switching to VGG16, for example, to get some extra space if you are still keen on VGG; otherwise, try a ResNet-50 for the best trade-off between performance and accuracy.

Here is a script that should be roughly compatible with your existing setup, using ResNet-50 and a batch size of 16 per device; try it out:

import os, time
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms

args = {
    'num_devices':4,
    'checkpoints':'checkpoints',
    'prefix':'test_model',
    'learning_rate': 0.001,
    'jitter_param': 0.2,
    'start_epoch':0,
    'end_epoch':5
}

config = {
    'BATCH_SIZE': 16,
    #'VAL_MX_REC': 'val.rec',
    #'TRAIN_MX_REC': 'train.rec',
    'MEAN': '0.485, 0.456, 0.406',
    'STD': '0.229, 0.224, 0.225',
}

def forward_backward(net, data, label):
    with mx.autograd.record():
        losses = [loss_fn(net(X), Y) for X, Y in zip(data, label)]
    for l in losses:
        l.backward()
        l.wait_to_read()
        

def train_batch(data_it, label_it, ctx, net, trainer):
    # Split the data batch and load them on GPUs
    data = gluon.utils.split_and_load(data_it, ctx)
    label = gluon.utils.split_and_load(label_it, ctx)
    # Compute gradient
    forward_backward(net, data, label)
    # Update parameters
    trainer.step(data_it.shape[0])


def valid_batch(data_it, label_it, ctx, net, metric):
    data = gluon.utils.split_and_load(data_it, ctx)
    labels = gluon.utils.split_and_load(label_it, ctx)

    preds = [mx.nd.argmax(net(X), axis=1) for X in data]

    for pred, label in zip(preds, labels):
        metric.update(label, pred)

#######################
# Data
#######################

batch_size = config["BATCH_SIZE"] * args["num_devices"]

if 'TRAIN_MX_REC' in config and 'VAL_MX_REC' in config:
    # ImageRecordDataset decodes each record into an (image, label) pair
    train_dataset = gluon.data.vision.ImageRecordDataset(config['TRAIN_MX_REC'])
    val_dataset = gluon.data.vision.ImageRecordDataset(config['VAL_MX_REC'])
else:
    train_dataset = gluon.data.ArrayDataset(mx.nd.ones((200, 480, 640, 3)), mx.nd.ones((200,))) # dummy 200 data points
    val_dataset = gluon.data.ArrayDataset(mx.nd.ones((200, 480, 640, 3)), mx.nd.ones((200,))) # dummy 200 data points

# mean and std for normalizing image value in range (0,1)
mean = [float(x) for x in config['MEAN'].split(',')]
std = [float(x) for x in config['STD'].split(',')]

jitter_param = args['jitter_param']

training_transform = transforms.Compose([
    transforms.RandomFlipLeftRight(),
    transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param, saturation=jitter_param),
    transforms.RandomLighting(jitter_param),
])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

train_iter_loader = gluon.data.DataLoader(train_dataset.transform_first(training_transform).transform_first(transform), batch_size, shuffle=True, num_workers=8)
val_iter_loader = gluon.data.DataLoader(val_dataset.transform_first(transform), batch_size, shuffle=False, num_workers=8)

#######################
# Model
#######################

# Construct the checkpoints path
checkpoints_path = os.path.sep.join([args["checkpoints"],
                                     args["prefix"]])

ctx = [mx.gpu(i) for i in range(0, args["num_devices"])] if args["num_devices"] > 0 else [mx.cpu()]

print("[INFO] Building network...")
model = gluon.model_zoo.vision.resnet50_v2(pretrained=True, ctx=ctx)
with model.name_scope():
    model.output = nn.Dense(2, weight_initializer="normal", bias_initializer="zeros")

model.output.initialize(mx.initializer.MSRAPrelu(), ctx=ctx)
model.hybridize(static_alloc=True, static_shape=True)

#########################
# Training Loop
#########################

# Trainer
trainer = gluon.Trainer(model.collect_params(), "adam", {"learning_rate": args["learning_rate"]}, update_on_kvstore=True, kvstore='device')

# Define our loss function
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

# Train the network
print("[INFO] Training network...")
for epoch in range(args["end_epoch"]):
    # Training Loop
    start = time.time()
    for i, (d, l) in enumerate(train_iter_loader):  # start of mini-batch
        train_batch(d, l, ctx, model, trainer)
        print(i)
    mx.nd.waitall()  # Wait until all computations are finished to benchmark the time
    print("[Epoch {}] Training Time = {:.1f} sec".format(epoch, time.time() - start))

    # Validation loop
    metric = mx.metric.Accuracy()
    for d, l in val_iter_loader:
        valid_batch(d, l, ctx, model, metric)
    mx.nd.waitall()
    print("\tValidation Accuracy = {:.2f}".format(metric.get()[1]))

    # Save a checkpoint
    print("Saving checkpoint file to {}...".format(checkpoints_path))
    model.export(checkpoints_path, epoch=epoch)
    model.save_parameters("{}-{}-gluon.params".format(checkpoints_path, epoch))

Hi, the problem is not multi-GPU itself. Even a single GPU has a large memory footprint with this network. Your network (with bias removed where unnecessary to reduce parameters):

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class VGG19(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(VGG19, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')

            # Block 1
            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,
                                        weight_initializer=mx.init.Xavier(rnd_type='gaussian',
                                                                          factor_type='out',
                                                                          magnitude=2),use_bias=False
                                        ))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 2
            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 3
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 4
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 5
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            self.features.add(nn.Flatten())
            
            # Block 6
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            # Block 7
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            self.output = nn.Dense(2, weight_initializer="normal", bias_initializer="zeros")

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x

has roughly 600M parameters, the majority of which come from the first Dense layer (the size of its input is your problem):

net = VGG19()
ctx=mx.cpu()
net.initialize(ctx=ctx)
#net.hybridize(static_shape=True, static_alloc=True)
net.summary(mx.nd.ones((1, 3, 480, 640)))
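
My rough arithmetic for where that number comes from (assuming the five stride-2 poolings above): the 480x640 input ends up as a 15x20 map with 512 channels before the flatten, so the first Dense layer by itself carries on the order of 630M weights, essentially the whole parameter budget:

# 480 / 2**5 = 15, 640 / 2**5 = 20, with 512 channels
flat = 15 * 20 * 512        # 153,600 inputs to the first Dense layer
dense1 = flat * 4096        # ~629,145,600 weights for that single layer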

Even if you use MobileNet-style convolutions, the problem will remain. What you can try to do is further reduce the size of the input to the first dense layer. You can do so by adding more convolution layers (a deeper network; usually the deeper the better) that summarize the last conv features further, or by using more aggressive max pooling in the last conv layer (e.g. pool = 4, stride = 4 or higher). For example, this modified architecture:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class VGG19(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(VGG19, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')

            # Block 1
            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,
                                        weight_initializer=mx.init.Xavier(rnd_type='gaussian',
                                                                          factor_type='out',
                                                                          magnitude=2),use_bias=False
                                        ))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 2
            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 3
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 4
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 5
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
                        

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # @@@@@@@@@@@@@@@@@@@@@@@@ MOD here @@@@@@@@@@@@@@@@@@@@@@@@@@@@
            self.features.add(nn.Conv2D(1024, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            self.features.add(nn.Conv2D(1024, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))
            
            self.features.add(nn.Conv2D(2048, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))
            # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
        
        
            self.features.add(nn.Flatten())
            
            # Block 6
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            # Block 7
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            self.output = nn.Dense(2, weight_initializer="normal", bias_initializer="zeros")

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x

has “only” ~87M parameters. Play with this idea and you can reduce the memory footprint even further.
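
The same rough arithmetic for the modified version (my numbers, assuming the default 2x2 pooling): the three extra conv + pool stages shrink the 15x20x512 map to about 1x2x2048, so the first Dense layer collapses from ~630M to ~17M weights:

flat = 1 * 2 * 2048         # 4,096 inputs after the extra pooling stages
dense1 = flat * 4096        # ~16.8M weights instead of ~630M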

All the best,
Foivos

Edit: I strongly recommend, as @ThomasDelteil suggested, moving to a better architecture (ResNet, DenseNet, etc.). Even implementing them on your own from scratch (based on the papers), if you don’t want the pre-defined models, is relatively easy.
