Training speed in MXNet is nearly 2.5x slower than PyTorch

Today I started using MXNet’s GluonCV ImageNet training script. I used the MobileNet1.0 bash config presented here (classification.html).
A single epoch takes more than 2 hours to complete (2 hours and 35 minutes, to be exact), while in PyTorch, for example, it took around 45 minutes on my GTX 1080.
I have a 4790K @ 4.5 GHz and a Samsung 840 EVO 250 GB, from which I’m reading my training data. I have both CUDA 9.0 and cuDNN 7.4 installed and ready.
GPU load is constantly at 99~100%.
The GPU fans are at 45% speed!
8.6G/15.6G of system RAM is used.
And I’m on Ubuntu 16.04.5.

mxnet version :  1.3.1
GluonCV version : 0.4.0
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 390.48                 Driver Version: 390.48                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 1080    Off  | 00000000:01:00.0  On |                  N/A |
|  46%  63C    P2    90W / 200W |   7380MiB /  8116MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      1091      G   /usr/lib/xorg/Xorg                           172MiB |
|    0      2867      G   compiz                                       184MiB |
|    0      15164     C   python                                      7017MiB |
+-----------------------------------------------------------------------------+

Although I can see 20 threads in htop, only one thread at a time consumes nearly 100% of CPU time; the other threads consume less than 16% each (and it gets lower for each remaining thread)!

By the way, this is how I started the training:

python train_imagenet.py \
  --rec-train /media/ssd/rec/train.rec --rec-train-idx /media/ssd/rec/train.idx \
  --rec-val /media/ssd/rec/val.rec --rec-val-idx /media/ssd/rec/val.idx \
  --model simpnet1.0 --mode hybrid \
  --lr 0.4 --lr-mode cosine --num-epochs 200 --batch-size 256 --num-gpus 1 -j 20 \
  --use-rec --dtype float16 --warmup-epochs 5 --no-wd --label-smoothing --mixup \
  --save-dir params_mobilenet1.0_mixup \
  --logging-file simpnet1.0_mixup.log

Update:
At first I thought this might be caused by using float16 as the dtype, so I also tried float32, but the problem persisted: the training performance did not change at all!
Also, the way the GPU is being utilized is very weird. While it reports 100% load, the temperature never goes beyond 62~64C, and the fans are nearly at idle speed. Usually a load of 100% results in a temperature of 70~72C.
The CPU utilization is weird too. It doesn’t matter whether I use 4 threads or 20 threads; the CPU utilization is almost the same.
When training in PyTorch, I’d use 20 worker threads, all 8 CPU threads were utilized nearly to the max, GPU utilization was between 89~99%, the temperature was around 72~74C, and each epoch would take around 45 minutes to complete, definitely not nearly 3.44x longer as in MXNet.
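
For reference, the 3.44x figure follows directly from the two epoch times reported above. A minimal sketch; the variable names are illustrative:

```python
# Epoch times as reported in this post, converted to minutes.
mxnet_epoch_min = 2 * 60 + 35   # 2 hours 35 minutes
pytorch_epoch_min = 45

# Ratio of the MXNet epoch time to the PyTorch epoch time.
slowdown = mxnet_epoch_min / pytorch_epoch_min
print(f"MXNet epoch takes {slowdown:.2f}x as long as the PyTorch epoch")
```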

I guess there must be a bug somewhere here; this doesn’t make sense to me at all.
Any help is greatly appreciated.
Thanks in advance.

Hi @Master,

Just to confirm, you’re doing a like-for-like comparison with PyTorch here? I notice you’re using features like mixup, which may improve the classification performance of the model but reduce the speed of training. And the same batch size, etc.? Also, what is simpnet1.0?

Hi, thank you very much for your response.
Honestly, I don’t think “mixup” could have such a huge impact as to decrease performance nearly 3 times, but I may well be wrong.
My comparison of MXNet vs PyTorch performance was based solely on training on ImageNet with the same procedure, except for mixup (actually, I just noticed it!), and yes, with the same batch size, data augmentation, etc.
As you can see, I have no considerable load on either my CPU or my GPU. Although the GPU reports 100% load, the GPU temperature and fan speed tell a different story! I’m genuinely puzzled here!
The simpnet definition is given here

Good detective work with the GPU temperatures :slight_smile:

nvidia-smi GPU Utilisation is defined as:

Percent of time over the past sample period during which one or more kernels was executing on the GPU. The sample period may be between 1 second and 1/6 second depending on the product.

So the “one or more kernels” part could explain what’s happening here. GPUs get their advantage from parallel processing, but it looks like you could potentially get 100% GPU Utilisation just by running a single kernel in serial. I think something isn’t parallelising.
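
A toy model can make the distinction concrete. This is only a sketch of the definition quoted above, not how the driver actually measures anything; `NUM_UNITS`, the slice counts, and both function names are hypothetical:

```python
def reported_utilization(busy_units_per_slice):
    """Fraction of time slices in which at least one kernel was running."""
    busy = sum(1 for n in busy_units_per_slice if n >= 1)
    return busy / len(busy_units_per_slice)


def average_occupancy(busy_units_per_slice, num_units):
    """Fraction of the device's parallel units that were busy on average."""
    used = sum(min(n, num_units) for n in busy_units_per_slice)
    return used / (len(busy_units_per_slice) * num_units)


NUM_UNITS = 20  # hypothetical number of parallel execution units

serial_workload = [1] * 100     # one unit busy in every time slice
parallel_workload = [18] * 100  # most units busy in every time slice

for label, workload in [("serial", serial_workload), ("parallel", parallel_workload)]:
    print(label,
          f"reported={reported_utilization(workload):.0%}",
          f"occupancy={average_occupancy(workload, NUM_UNITS):.0%}")
```

Both workloads report 100% under the "one or more kernels" definition, even though the serial one keeps only a small fraction of the device busy, which would be consistent with the low temperature and fan speed.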

You should use the profiling tool on your code, and the suspect operator should be easily spotted. See this tutorial for an explanation of how to use it. It integrates with Chrome’s tracing tool, if you prefer a nice user interface.
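
Once a profile has been dumped, ranking operators by total time is one way to spot the suspect without opening a viewer. A minimal sketch, assuming the dump is in Chrome's Trace Event JSON format with `B`/`E` pairs or `X` events; the function names, the file path you would pass to `load_trace`, and the sample events (with made-up timings) are illustrative and not taken from a real MXNet run:

```python
import json
from collections import defaultdict


def load_trace(path):
    """Load a Chrome Trace Event JSON file from disk."""
    with open(path) as f:
        return json.load(f)


def total_time_per_name(trace):
    """Return (event name, total duration) pairs, longest first."""
    # The format allows either a bare list or an object with "traceEvents".
    events = trace["traceEvents"] if isinstance(trace, dict) else trace
    totals = defaultdict(float)
    open_begins = defaultdict(list)  # (pid, tid, name) -> begin timestamps
    for ev in events:
        phase, name = ev.get("ph"), ev.get("name")
        key = (ev.get("pid"), ev.get("tid"), name)
        if phase == "X":            # complete event carrying its own duration
            totals[name] += ev.get("dur", 0)
        elif phase == "B":          # begin event: remember the timestamp
            open_begins[key].append(ev["ts"])
        elif phase == "E" and open_begins[key]:  # end event: close latest begin
            totals[name] += ev["ts"] - open_begins[key].pop()
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


# Synthetic events for illustration only; timestamps are in microseconds.
sample_trace = {"traceEvents": [
    {"name": "Convolution", "ph": "B", "ts": 0, "pid": 0, "tid": 1},
    {"name": "Convolution", "ph": "E", "ts": 300, "pid": 0, "tid": 1},
    {"name": "Dropout", "ph": "B", "ts": 300, "pid": 0, "tid": 1},
    {"name": "Dropout", "ph": "E", "ts": 2300, "pid": 0, "tid": 1},
    {"name": "Activation", "ph": "X", "ts": 2300, "dur": 50, "pid": 0, "tid": 1},
]}

ranking = total_time_per_name(sample_trace)
for name, micros in ranking:
    print(f"{name:<12} {micros:>8.0f} us")
```

Matching `E` events to `B` events by name is an assumption of this sketch; a trace whose end events omit the name would need matching by thread only.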

My hunch is that it could be the ‘dropout’ operator. It seems you set the dropout probability to 0 anyway (i.e. no effect), so you could try removing those unnecessary layers.

Good luck!

Truly a magnificent catch! Thanks a lot :slight_smile:
Dropout(0) was indeed the culprit. After commenting it out, not only did training become 6x faster, the memory consumption also decreased drastically (from 7.xG to 5.xG)!
I believe this is indeed a bug that needs to be fixed as well.
Checking for a dropout ratio less than or equal to 0 should disable the layer, or at least avoid such a big slowdown!
As to why I would want to use a dropout of 0: I want to have the dropout layer in the graph, so that when I resume training, in case dropout is needed, I can simply set a ratio and continue. This was my reason in PyTorch, since adding a dropout layer in other ways was more cumbersome.
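
Until the layer itself handles this case, the check can live in the model-building code. A minimal sketch with a hypothetical helper, `add_dropout_if_active`, written against any container that has an `add` method; the stand-in classes below exist only so the example runs without MXNet installed:

```python
def add_dropout_if_active(container, rate, dropout_factory):
    """Append a dropout layer only when rate > 0.

    Returns True if a layer was added, False if it was skipped.
    """
    if rate <= 0:
        return False  # a rate of 0 has no effect, so skip the layer entirely
    container.add(dropout_factory(rate))
    return True


# Stand-ins for illustration; with Gluon one would presumably pass an
# nn.HybridSequential as the container and nn.Dropout as the factory.
class FakeSequential:
    def __init__(self):
        self.layers = []

    def add(self, layer):
        self.layers.append(layer)


class FakeDropout:
    def __init__(self, rate):
        self.rate = rate


net = FakeSequential()
added_zero = add_dropout_if_active(net, 0.0, FakeDropout)  # skipped
added_half = add_dropout_if_active(net, 0.5, FakeDropout)  # appended
print(added_zero, added_half, len(net.layers))
```

Note that skipping the layer changes the number of children in the container, so whether a checkpoint saved without the layer loads cleanly into a model built with it is something to verify separately.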
Anyway, Thanks a gazillion times again .
Have a wonderful weekend:)

Awesome! Glad it worked. Would you be able to share your final model so I can open a GitHub issue for the performance problem you found? I see that most of the code is here, but that version has some errors in it that you’ve fixed. It would be great to get the fixed-up version, and the shape of the input you’re using for these models. Cheers, Thom

Sure thing. Here you are; this is what’s being trained so far:

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ,unused-argument,missing-docstring
"""SimpNet , implemented in Gluon."""
__all__ = [
    'SimpNet',
    'simpnet1_0',
    'simpnet0_75',
    'simpnet0_5',
    'simpnet0_25',
    'get_simpnet']

__modify__ = 'dwSun'
__modified_date__ = '18/04/18'

from mxnet.gluon import nn
from mxnet.context import cpu
from mxnet.gluon.block import HybridBlock
from mxnet import gluon


# Helpers
class RELU6(nn.HybridBlock):
    """Relu6 used in SimpNetV2."""

    def __init__(self, **kwargs):
        super(RELU6, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        return F.clip(x, 0, 6, name="relu6")


# pylint: disable= too-many-arguments
def _add_conv(out, input_channels, filters=1, kernel=3, stride=1, pad=1,
              active=True, relu6=False, num_sync_bn_devices=-1, dropout=False):

    out.add(nn.Conv2D(in_channels=input_channels,
                      channels=filters,
                      kernel_size=kernel,
                      strides=stride,
                      padding=pad,
                      use_bias=False))

    if num_sync_bn_devices <= 1:
        out.add(nn.BatchNorm(in_channels=filters, scale=True))
    else:
        out.add(gluon.contrib.nn.SyncBatchNorm(scale=True, num_devices=num_sync_bn_devices))
    if active:
        out.add(RELU6() if relu6 else nn.Activation('relu'))
    #if dropout:
        # out.add(nn.Dropout(0))    

# Net
class SimpNet(HybridBlock):
    r"""
    Parameters
    ----------
    multiplier : float, default 1.0
        The width multiplier for controlling the model size. Only multipliers that are no
        less than 0.25 are supported. The actual number of channels is equal to the original
        channel size multiplied by this multiplier.
    classes : int, default 1000
        Number of classes for the output layer.
    num_sync_bn_devices : int, default is -1
        Number of devices for training. If `num_sync_bn_devices < 2`, SyncBatchNorm is disabled.
    """

    def __init__(self, multiplier=1.0, classes=1000, num_sync_bn_devices=-1, network_idx=0, s_mode=2, **kwargs):
        super(SimpNet, self).__init__(**kwargs)

        self.cfg = {
        'simpnet5m': [['C', 66], ['C', 128], ['C', 128], ['C', 128], ['C', 192], ['C', 192], ['C', 192], ['C', 192], ['C', 192], ['C', 288], ['P'], ['C', 288], ['C', 355], ['C', 432]],
        'simpnet8m': [['C', 128], ['C', 182], ['C', 182], ['C', 182], ['C', 182],  ['C', 182], ['C', 182], ['C', 182], ['C', 182], ['C', 430], ['P'], ['C', 430], ['C', 455], ['C', 600]]}
        self.scale = multiplier
        self.networks = ['simpnet5m', 'simpnet8m']
        self.network_idx = network_idx
        self.mode = s_mode
        self.strides = {1: [2, 2, 2, 1, 1],      #s1
                        2: [2, 2, 1, 2, 1, 1],   #s4
                        3: [2, 2, 1, 1, 2, 1],   #s3
                        4: [2, 1, 2, 1, 2, 1],   #s5
                        5: [2, 1, 2, 1, 2, 1, 1]}#s6

        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            with self.features.name_scope():

                layers = []
                input_channel = 3
                idx = 0

                for x in self.cfg[self.networks[self.network_idx]]:
                    if idx == len(self.strides[self.mode]) or x[0] == 'P':
                        self.features.add(nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding=0, layout='NCHW', ceil_mode=False))
                        # self.features.add(nn.Dropout(0))

                    if x[0] != 'C':
                        continue
                    filters = round(x[1] * self.scale)
                    if idx < len(self.strides[self.mode]):
                        stride = self.strides[self.mode][idx]
                    else:
                        stride = 1
                    if idx in (len(self.strides[self.mode])-1, 9, 12):
                        _add_conv(self.features, input_channels=input_channel, filters=int(filters), kernel=3, pad=1, stride=stride,
                              num_sync_bn_devices=num_sync_bn_devices)

                    else:
                        _add_conv(self.features, input_channels=input_channel, filters=int(filters), kernel=3, pad=1, stride=stride,
                              num_sync_bn_devices=num_sync_bn_devices, dropout=True)

                    input_channel = filters
                    idx += 1
                print('pkhere!')
                self.features.add(nn.GlobalMaxPool2D())
                # self.features.add(nn.Dropout(0))
                self.features.add(nn.Flatten())

            numx=round(self.cfg[self.networks[network_idx]][-1][1] * self.scale)
            #classifier
            self.output = nn.Dense(in_units=numx, units=classes)

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x

# Constructor
def get_simpnet(multiplier, pretrained=False, ctx=cpu(),
                  root='~/.mxnet/models', num_sync_bn_devices=-1, **kwargs):
    r"""
    Parameters
    ----------
    multiplier : float
        The width multiplier for controlling the model size. Only multipliers that are no
        less than 0.25 are supported. The actual number of channels is equal to the original
        channel size multiplied by this multiplier.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    num_sync_bn_devices : int, default is -1
        Number of devices for training. If `num_sync_bn_devices < 2`, SyncBatchNorm is disabled.
    """
    net = SimpNet(multiplier, num_sync_bn_devices=num_sync_bn_devices, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        version_suffix = '{0:.2f}'.format(multiplier)
        if version_suffix in ('1.00', '0.50'):
            version_suffix = version_suffix[:-1]
        net.load_parameters(get_model_file('SimpNet%s' % version_suffix,
                                           tag=pretrained,
                                           root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net


def simpnet1_0(**kwargs):
    r"""
    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    num_sync_bn_devices : int, default is -1
        Number of devices for training. If `num_sync_bn_devices < 2`, SyncBatchNorm is disabled.
    """
    return get_simpnet(1.0, **kwargs)



def simpnet0_75(**kwargs):
    r"""
    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    num_sync_bn_devices : int, default is -1
        Number of devices for training. If `num_sync_bn_devices < 2`, SyncBatchNorm is disabled.
    """
    return get_simpnet(0.75, **kwargs)



def simpnet0_5(**kwargs):
    return get_simpnet(0.5, **kwargs)


def simpnet0_25(**kwargs):
    r"""
    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    num_sync_bn_devices : int, default is -1
        Number of devices for training. If `num_sync_bn_devices < 2`, SyncBatchNorm is disabled.
    """
    return get_simpnet(0.25, **kwargs)

Just raised a GitHub issue for this: https://github.com/apache/incubator-mxnet/issues/13825.

And thanks for sending through the code! I was able to replicate, and used it to create a minimum reproducible example. :+1:

Hi, just to add some more info that may be relevant: SyncBatchNorm is, unfortunately, very slow to use (for me, close to the point of being unusable). I get great performance and stable training with a small batch size, but it takes forever (the method of delayed gradient updates with an increased batch size is faster in my experiments). If the PyTorch version you are comparing against does not use it, that may be another factor in the performance difference.
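
The delayed-update idea mentioned above can be illustrated without any framework. A minimal sketch, assuming a mean-squared-error loss on a one-parameter model and equally sized micro-batches; all names and data values are illustrative. It only shows that averaged micro-batch gradients match the large-batch gradient, and says nothing about batch normalization statistics:

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error of y ~ w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

# Gradient computed on the whole batch of 4 samples at once.
full_batch_grad = grad_mse(w, xs, ys)

# Same gradient obtained by accumulating over micro-batches of 2 samples
# and applying a single (delayed) update at the end.
micro_batch_size = 2
accumulated = 0.0
num_micro_batches = 0
for start in range(0, len(xs), micro_batch_size):
    stop = start + micro_batch_size
    accumulated += grad_mse(w, xs[start:stop], ys[start:stop])
    num_micro_batches += 1
delayed_grad = accumulated / num_micro_batches

print(full_batch_grad, delayed_grad)
```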