GPU memory usage

What is the expected GPU memory usage for this model (~4M parameters)? When training on a single GPU with a batch size of 250 or more, it runs out of memory (each GPU has 11439 MiB of memory).

model = mx.gluon.nn.Sequential(prefix='model_')
with model.name_scope():           
    model.add(Dense(1800))
    model.add(Activation("relu"))
    model.add(BatchNorm())
    model.add(Dropout(0.25))

    for i in range(7):
        model.add(Dense(400))
        model.add(Activation("relu"))
        model.add(BatchNorm())
        model.add(Dropout(0.25))

    model.add(Dense(num_classes))
    model.add(BatchNorm())

Input shape: (1038708, 903)
Number of classes: 522 (one-hot labels)
batch_size=200
learning_rate = 0.01
loss = SoftmaxCrossEntropyLoss(sparse_label=False)
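
For reference, counting parameters from the layer sizes above gives roughly 3.5M, consistent with the ~4M figure (a rough sketch; BatchNorm scale/shift included, running statistics excluded):

# Rough parameter count for the MLP described above
dims = [903, 1800] + [400] * 7 + [522]                       # input width, hidden layers, output layer
dense = sum(i * o + o for i, o in zip(dims[:-1], dims[1:]))  # weights + biases
bn = sum(2 * d for d in dims[1:])                            # gamma + beta per BatchNorm
print(dense + bn)                                            # ~3.53M parameters, ~14 MB in float32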

Could you post the complete code that triggers the OOM? You can use dummy data (all zeros).

@piiswrong Thanks for helping!

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from mxnet import nd
        
        
class DatasetBuilder(object):
    def __init__(self):
        pass
    
    def buildDatasets(self):
        # Load the pre-computed feature and one-hot label arrays from disk as float32
        self.X_train = np.load("/dev/asin2vec/xTrain.npy").astype("float32")
        self.Y_train = np.load("/dev/asin2vec/yTrain.npy").astype("float32")
        
        self.X_valid = np.load("/dev/asin2vec/xVal.npy").astype("float32")
        self.Y_valid = np.load("/dev/asin2vec/y_val.npy").astype("float32")
  
        print("Loaded training and validation data")

db = DatasetBuilder()
db.buildDatasets()

print("Input shape: {}".format(db.X_train.shape))
print("Number of classes: {}".format(db.Y_train.shape[1]))

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import mxnet as mx
from mxnet import gpu, gluon, nd, autograd
from mxnet.gluon import HybridBlock, Block
from mxnet.gluon.nn import Embedding, HybridSequential, Dense, BatchNorm, Activation, Conv1D, MaxPool1D, Dropout, Flatten
from mxnet.gluon.rnn import LSTM
from mxnet.gluon.loss import SoftmaxCrossEntropyLoss, SigmoidBinaryCrossEntropyLoss
from mxnet.gluon.data import ArrayDataset
from mxnet.io import NDArrayIter
from time import time
import sys


def createSimpleModel(num_classes):
    model = mx.gluon.nn.HybridSequential(prefix='model_')
    with model.name_scope():            
        model.add(Dense(1800))
        model.add(Activation("relu"))
        model.add(BatchNorm())
        model.add(Dropout(0.25))

        for i in range(7):
            model.add(Dense(400))
            model.add(Activation("relu"))
            model.add(BatchNorm())
            model.add(Dropout(0.25))

        model.add(Dense(num_classes))
        model.add(BatchNorm())

    return model
            

class ModelTrainer(object):
    def __init__(self):
        self.loss = SoftmaxCrossEntropyLoss(sparse_label=False)

    def trainModel(self, dataset_builder, batch_size, learning_rate, epochs, num_gpu=1):
        def forward_backward(model, data, labels):
            # Forward pass under autograd on each device slice, then backprop each loss
            with autograd.record():
                losses = [self.loss(model(X), Y) for X, Y in zip(data, labels)]
            for l in losses:
                l.backward()

        def train_batch(data, labels, ctx, model, trainer):
            size = data.shape[0]
            # split the data batch and load them on GPUs
            data = gluon.utils.split_and_load(data, ctx, even_split=False)
            labels = gluon.utils.split_and_load(labels, ctx, even_split=False)

            forward_backward(model, data, labels)
            # update parameters
            trainer.step(size)

        def valid_batch(data, labels, ctx, model):
            # Run inference on the first GPU and count hits against the one-hot labels
            data = data.as_in_context(ctx[0])
            pred = np.argmax(model(data).asnumpy(), axis=1)
            return sum([l[pred[i]] for i, l in enumerate(labels.asnumpy())])
        
        def valid_all(data_loader, ctx, model):
            correct, num = 0.0, 0.0
            for (data, labels) in data_loader:
                correct += valid_batch(data, labels, ctx, model)
                num += data.shape[0]
            return correct / num

        model = createSimpleModel(dataset_builder.Y_train.shape[1])
        model.hybridize()
        model.initialize()

        print('Batch size is {}'.format(batch_size))
            
        # training on multiple GPUs
        batch_size *= num_gpu
        ctx = [gpu(i) for i in range(num_gpu)]
        print('Running on {}'.format(ctx))
        model.collect_params().initialize(mx.init.Xavier(), force_reinit=True, ctx=ctx)
        
        # data iterators
        train_data = mx.gluon.data.DataLoader(
            ArrayDataset(dataset_builder.X_train, dataset_builder.Y_train),
            batch_size, shuffle=True, last_batch='rollover')
        train_valid_data = mx.gluon.data.DataLoader(
            ArrayDataset(dataset_builder.X_train, dataset_builder.Y_train),
            batch_size, sampler=mx.gluon.data.RandomSampler(min(25000, len(dataset_builder.X_train))),
            shuffle=False, last_batch='keep')
        valid_data = mx.gluon.data.DataLoader(
            ArrayDataset(dataset_builder.X_valid, dataset_builder.Y_valid),
            batch_size, sampler=mx.gluon.data.RandomSampler(min(25000, len(dataset_builder.X_valid))),
            shuffle=False, last_batch='keep')
        trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': learning_rate * num_gpu})
        for epoch in range(epochs):
            # train
            start = time()
            for (data, labels) in train_data:
                train_batch(data, labels, ctx, model, trainer)

            nd.waitall()  # wait until all computations are finished to benchmark the time
            print('Epoch %d, training time = %.1f sec' % (epoch, time() - start))

            # validating
            print('         train accuracy = %.4f' % valid_all(train_valid_data, ctx, model))
            print('         test accuracy = %.4f' % valid_all(valid_data, ctx, model))
            sys.stdout.flush()

modelTrainer = ModelTrainer()

for i in range(1,9):
    print("Training on {} GPUs".format(i))
    modelTrainer.trainModel(dataset_builder=db, batch_size=400, learning_rate=0.01, epochs=10, num_gpu=i)
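
As suggested above, the real arrays can also be swapped for dummy all-zero data to reproduce the issue. A hypothetical stand-in (the validation-set size is a made-up placeholder; the other shapes follow the numbers reported earlier):

# Hypothetical dummy-data stand-in (all zeros) for reproducing the OOM without the real .npy files
db = DatasetBuilder()
db.X_train = np.zeros((1038708, 903), dtype="float32")
db.Y_train = np.zeros((1038708, 522), dtype="float32")   # one-hot labels, 522 classes
db.X_valid = np.zeros((100000, 903), dtype="float32")    # placeholder validation-set size
db.Y_valid = np.zeros((100000, 522), dtype="float32")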

It fails already when training on a single GPU.

We just trained the exact same model in Keras/TensorFlow on a single GPU; it handles 10000 samples per batch just fine (and runs out of resources at 20000). Is there something obviously wrong in the code above?

An update:

We upgraded MXNet from 0.11 to 0.12 and tried training the same model; this time it was able to handle batch sizes up to 1000.

Have you tried using hybridize() and the Hybrid blocks? In general, imperative execution cannot, in theory, be very memory efficient for arbitrary models.

A very good point, @botev. In fact, we did switch to a HybridSequential model and hybridized it when we tested on MXNet 0.12, and that could very well be why we managed to get to batch sizes of up to 1000 (rather than the 0.12 upgrade itself). But we are still way behind Keras/TensorFlow, where we could scale to batch sizes of up to 15000.

Are you seeing anything obviously wrong with the code?

I don’t see any obvious issue with the code. Note, however, that I have never used MXNet myself, so I’m quite the newbie. Also note that you need to call hybridize() explicitly to gain the benefits of the Hybrid blocks. If the issue remains, I would personally open an issue on GitHub for the people responsible for the memory optimizer, as this seems like an easy optimization that is not happening.

The code above does have an explicit hybridize() call. @piiswrong - do you have any additional concerns with the code? Is there a ballpark formula to estimate the memory usage of the forward and backward passes as a function of the input dimensions (batch size and single-sample size) and the number of parameters? At least for a simple multi-layer perceptron…

Again, thanks for helping!

For a feed-forward net like yours, the “optimal” memory is the parameter memory plus N * D, where N is the batch size and D is the maximum layer width. The worst case, I think, is the parameter memory plus N * C * sum(D), where C = 2 or 3 and sum(D) is the sum of the layer widths.

Yep, I thought along the same lines… I am still struggling to explain why it exceeds 11 GB per GPU and runs out of memory; it should be in the hundreds of MBs in the worst case.
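
For a rough sanity check with the formula above (a sketch assuming float32 activations, batch size 1000, and C = 3; gradients and Adam state would add roughly another 3x the parameter memory):

# Back-of-the-envelope memory estimate using the formula above
bytes_per_float = 4                                                      # float32
batch_size = 1000
dims = [903, 1800] + [400] * 7 + [522]                                   # layer widths, including the input
params = sum(i * o + o for i, o in zip(dims[:-1], dims[1:])) + sum(2 * d for d in dims[1:])

param_mem  = params * bytes_per_float                                    # ~14 MB
best_case  = param_mem + batch_size * max(dims) * bytes_per_float        # params + N * D_max      ~21 MB
worst_case = param_mem + batch_size * 3 * sum(dims) * bytes_per_float    # params + N * C * sum(D) ~87 MB
print(param_mem / 1e6, best_case / 1e6, worst_case / 1e6)                # all well under 1 GB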

What is the purpose of the following line:
for i in range(1,9):

To train sequentially with num_gpu = 1…8. It could be removed and the number of GPUs hardcoded.

Interestingly, what I found is that GPU memory is not freed between iterations, only after each epoch, so for large datasets we quickly run out of GPU memory. Does this have something to do with the DataLoader/iterator?

It’s because you didn’t synchronize in each iteration.
You can move mx.nd.waitall() into the for (data, labels) in train_data: loop.
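
Something along these lines (a minimal sketch of the suggested change; note it adds a synchronization point to every batch):

for epoch in range(epochs):
    start = time()
    for (data, labels) in train_data:
        train_batch(data, labels, ctx, model, trainer)
        nd.waitall()   # synchronize every iteration so temporaries can be freed promptly
    print('Epoch %d, training time = %.1f sec' % (epoch, time() - start))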


Thanks a ton, it works on large datasets now.


Sorry, I have another question related to your post: why do you use batch norm after the activation?

Good point! That was an oversight and a copy-and-paste error; thanks for pointing it out. The real (non-test) model has it in the correct order.
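
For reference, the conventional ordering applies BatchNorm to the pre-activations, i.e. Dense, then BatchNorm, then the activation (a sketch of a single hidden block):

model.add(Dense(400))
model.add(BatchNorm())           # normalize the pre-activations
model.add(Activation("relu"))
model.add(Dropout(0.25))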
