NaN in loss after several epochs in SemSeg problem


#1

Dear all,

I am having a very weird issue, which is most probably an mxnet bug, but I am not sure. Any ideas or pointers to help me identify it are most welcome.

I have a semantic segmentation problem. The architecture I am using is a tweaked version of the one in this paper. A visual representation of the architecture is shown in Fig. 1,


where each feature map is a modified ResNet v2 block that has multiple parallel branches with atrous convolutions, like this:

ResNet_atrous_unit(
  (ResBlock1): ResNet_v2_block(
    (BN1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (conv1): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (BN2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (conv2): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (ResBlock2): ResNet_v2_block(
    (BN1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (conv1): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
    (BN2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (conv2): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))
  )
  (ResBlock3): ResNet_v2_block(
    (BN1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (conv1): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(15, 15), dilation=(15, 15), bias=False)
    (BN2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (conv2): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(15, 15), dilation=(15, 15))
  )
  (ResBlock4): ResNet_v2_block(
    (BN1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (conv1): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(31, 31), dilation=(31, 31), bias=False)
    (BN2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (conv2): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), padding=(31, 31), dilation=(31, 31))
  )
)

so the forward call is something like

    def hybrid_forward(self, F, _xl):

        # The output is out = _xl + sum_i ResBlock_i(_xl)
        x = _xl

        # Uniform description for both Symbol and NDArray
        x = F.broadcast_add(x, self.ResBlock1(_xl))
        x = F.broadcast_add(x, self.ResBlock2(_xl))
        x = F.broadcast_add(x, self.ResBlock3(_xl))
        x = F.broadcast_add(x, self.ResBlock4(_xl))

        return x
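
A note on why the element-wise sum in this forward call is shape-safe: every branch has to return a feature map of the same spatial size as its input. For a 3x3 kernel with stride 1 this holds whenever padding equals dilation, which is the pattern in the printout above. A minimal sketch of the standard output-size formula follows; the helper name `conv_out_size` and the 256-pixel input size are illustrative assumptions, not part of the original code.

```python
def conv_out_size(n, kernel, stride, padding, dilation):
    """Spatial output size of a convolution along one axis (standard formula)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (n + 2 * padding - effective_kernel) // stride + 1


# padding == dilation for every branch (ResBlock1..ResBlock4) in the printout.
# The input size of 256 is an arbitrary example value.
for rate in (1, 3, 15, 31):
    out = conv_out_size(256, kernel=3, stride=1, padding=rate, dilation=rate)
    print(rate, out)
```

Each branch leaves the spatial size unchanged, so the four outputs can be added to `_xl` directly.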

In total I have 16 feature maps like this one, plus some additional layers for upsampling and max pooling (PSP pooling). The number of filters starts at 64 and goes up to 1024 progressively.

The loss function is based on the generalized Dice coefficient from Sudre et al. 2017.
This is my implementation (translated to mxnet from the authors' tensorflow implementation), and it has been working fine:

class GDCoeff(Loss):
    """
    Generalized Dice coefficient (Sudre et.al. 2017) for the case of multiclass problems. 
    There are variations in the definition across the literature. Some authors calculate a single IoU from the whole batch by summing also in batch axis. 
    I prefer to get the mean over batches, and return (eventually) the average GDCoeff over all classes, per image. But see source code for details 
    """
    def __init__(self, _smooth=1.0e-3, _axis=[2,3], _weight = None, _batch_axis= 0, **kwards):
        Loss.__init__(self,weight=_weight, batch_axis = _batch_axis, **kwards)

        self.axis = _axis
        self.smooth = _smooth

    def hybrid_forward(self,F,_preds, _label):


        # Evaluate the mean volume of class per batch
        Vli = F.mean(F.sum(_label,axis=[2,3]),axis=0)
        wli =  1.0/Vli**2 # weighting scheme 

        # ---------------------This line is taken from niftyNet package -------------- 
        # ref: https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py, lines:170 -- 172  
        # new_weights = tf.where(tf.is_inf(weights), tf.zeros_like(weights), weights)
        # weights = tf.where(tf.is_inf(weights), tf.ones_like(weights) * tf.reduce_max(new_weights), weights)
        # --------------------------------------------------------------------

        # ***********************************************************************************************
        # First turn inf elements to zero, then replace that with the maximum weight value  
        # This produces higher - in general - gradient values, so it helps in training! 
        new_weights = F.where(wli == float('inf'), F.zeros_like(wli), wli)
        wli = F.where(wli == float('inf'), F.broadcast_mul(F.ones_like(wli), F.max(new_weights)), wli)
        # ************************************************************************************************



        rl_x_pl = F.sum( F.broadcast_mul(_label , _preds), axis=self.axis)
        rl_p_pl = F.sum(_preds +_label,axis=self.axis)

        gdc = 2.0 * F.sum( F.broadcast_mul(wli , rl_x_pl),axis=1) / F.sum( F.broadcast_mul(wli,(rl_p_pl)),axis=1)

        return gdc # This returns the gdc for EACH data point, i.e. a vector of values equal to the batch size 
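
To make the weighting step concrete, here is a small pure-Python sketch of the same arithmetic for a single image. It is not the mxnet code above: the function names are hypothetical, the class volume is taken from one image rather than averaged over a batch, and lists stand in for arrays. It shows how a class with zero volume gets an infinite `1/V^2` weight that is then replaced by the largest finite weight.

```python
import math


def gdc_class_weights(volumes):
    """1/V^2 weights; infinite weights (empty classes) become the largest finite one."""
    raw = [1.0 / (v * v) if v > 0 else math.inf for v in volumes]
    # If every weight is infinite the cap is 0.0, matching max() over the zeroed weights.
    cap = max((w for w in raw if math.isfinite(w)), default=0.0)
    return [w if math.isfinite(w) else cap for w in raw]


def generalized_dice(label, pred, weights):
    """label, pred: per-class lists of per-pixel values for ONE image."""
    num = sum(w * sum(l * p for l, p in zip(lc, pc))
              for w, lc, pc in zip(weights, label, pred))
    den = sum(w * sum(l + p for l, p in zip(lc, pc))
              for w, lc, pc in zip(weights, label, pred))
    return 2.0 * num / den


# Two classes, four pixels, one-hot labels; a perfect prediction equals the label.
label = [[1, 1, 1, 0], [0, 0, 0, 1]]
weights = gdc_class_weights([sum(c) for c in label])
print(generalized_dice(label, label, weights))  # close to 1 for a perfect prediction
print(gdc_class_weights([4, 0]))                # empty class takes the finite weight
```

In this sketch a zero denominator raises `ZeroDivisionError`; in an array framework the same 0/0 would show up as NaN instead. Note also that `self.smooth` is stored in `__init__` but does not appear in `hybrid_forward` as posted, so nothing guards the division there.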

In the past I have run this model for 200+ generations (each generation is ~25k iterations) without a problem, with several restarts and changes of the learning rate. Then I updated mxnet and my python distribution (from 2.7 to 3.6) to the latest versions from pip. Now training progresses normally at first, but at some point I see a sudden drop in the validation loss (from ~76% to ~30%). This is not justified by my experience from previous runs, when things were working: overfitting has never given below 70% in the past. After that, nan starts appearing in the training loss. The problem does not appear consistently, and I am a bit lost on how to debug it.

The whole network runs in parallel on 4 GPUs (P100), and I am using a delayed update of the gradients to increase the effective batch size, i.e.

    for param in self.mynet.collect_params().values():
        param.grad_req = 'add'

The update happens after 4 aggregations of the gradients with

    def _forward_backward_step(self, _iteration, _data, _label, _trainer=None):

        with autograd.record():
            # First argument is PREDICTIONS, second LABELS
            losses = [self.loss(self.mynet(inputs), labels) for inputs, labels in zip(_data, _label)]

        # This is outside the autograd.record scope
        for l in losses:  # Evaluate gradients in each ctx
            l.backward()

        # This updates gradients across ALL devices, by first aggregating them. <3 Gluon!
        if _iteration % self.config[C.C_UPDT_DELAY_RATE] == 0:

            if _trainer is None:
                self.trainer.step(self.config[C.C_BATCH_SIZE] * self.config[C.C_UPDT_DELAY_RATE])
            else:
                _trainer.step(self.config[C.C_BATCH_SIZE] * self.config[C.C_UPDT_DELAY_RATE])

            for param in self.mynet.collect_params().values():
                param.zero_grad()

        return losses
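
The arithmetic behind this delayed update can be checked with a toy model. The sketch below is an assumption-laden simplification: one scalar parameter, plain SGD instead of adam, made-up per-sample gradients, and a hypothetical function name. It only illustrates that summing gradients over `delay_rate` batches and dividing by `batch_size * delay_rate` gives the mean gradient of the enlarged batch.

```python
def accumulate_and_step(per_sample_grads, batch_size, delay_rate, lr, weight=0.0):
    """Toy one-parameter SGD with delayed updates.

    The accumulated gradient is divided by batch_size * delay_rate at each step,
    mirroring the value passed to trainer.step(...) in the code above.
    """
    grad_sum = 0.0
    n_batches = len(per_sample_grads) // batch_size
    for it in range(1, n_batches + 1):  # iterations counted from 1 (an assumption)
        batch = per_sample_grads[(it - 1) * batch_size: it * batch_size]
        grad_sum += sum(batch)          # the role of grad_req='add' across backward() calls
        if it % delay_rate == 0:
            weight -= lr * grad_sum / (batch_size * delay_rate)
            grad_sum = 0.0              # the role of zero_grad()
    return weight


# 8 samples, batches of 2, update every 4 batches: one step on the mean of all 8.
print(accumulate_and_step([1.0] * 8, batch_size=2, delay_rate=4, lr=0.5))
```

One detail worth checking in the real loop: if `_iteration` starts at 0, the condition `_iteration % rate == 0` is already true on the first iteration, so the first update would use a single accumulated batch while still being normalized by the full `batch_size * rate`.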

This is an example of the training GDC (it should be increasing; the best value is 1, and the loss is 1-GDC)

I rule out nans caused by empty or corrupted masks, since all images have masks (with data augmentation using skimage). Also, the nans start appearing only after the algorithm has started converging, usually after some restarts. In the restarts I am not using saved instances of the optimizer. The optimizer is adam with lr = 0.001 (or less) and everything else at default values. The runs are on an HPC cluster with python3.6 and 4 Tesla P100 GPUs. The mxnet version is cu90, v1.2.0 (I've also verified the problem with cu91).

Any ideas on where or how to look and debug? Many thanks. Can this be due to exploding gradients caused by the many layers of the network? Or due to the delayed update, where I manually aggregate the gradients? But then again, I didn't have this problem in the past.

Thank you very much for your time.
Foivos


#2

Hi @feevos,

Thanks for the detailed post. I haven't read the paper yet. A few generic pieces of advice:

  • If possible, roll back your python and mxnet upgrades. Do one at a time and see whether the nan losses are correlated with the python upgrade or with the mxnet upgrade.
  • Debug your code, for example with conditional breakpoints, to better understand what situation leads to your nan losses. You can watch this video from Sina https://www.youtube.com/watch?v=6-dOoJVw9_0 on debugging MXNet code interactively with pycharm.
  • Exploding gradients could lead to the situation you are describing. Have you tried putting your learning rate on a schedule?
  • Is it running on 4 GPUs on a single machine or on different machines? Can you share how you save / load the parameters?
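
As a lightweight alternative to a conditional breakpoint, a check like the following could stop training at the first non-finite loss so the offending iteration can be inspected. This is a sketch on plain Python floats with hypothetical function names; converting the per-device mxnet losses to Python numbers first (for example via `.asnumpy()`) is assumed and not shown.

```python
import math


def first_non_finite(values):
    """Return the index of the first NaN/inf in values, or None if all are finite."""
    for i, v in enumerate(values):
        if not math.isfinite(v):
            return i
    return None


def check_losses(iteration, losses):
    """Raise with the iteration number as soon as a loss stops being finite."""
    bad = first_non_finite(losses)
    if bad is not None:
        raise FloatingPointError(
            "non-finite loss %r at iteration %d, position %d" % (losses[bad], iteration, bad))


check_losses(0, [0.31, 0.29, 0.30, 0.28])  # all finite: passes silently
```

Calling `check_losses` before the trainer step also keeps a bad batch from being written into the parameters, which makes the state just before the failure easier to examine.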

Good luck!


#3

Thanks a lot for the detailed info!

  • Learning rate schedule: no, I mostly train until convergence with a fixed LR (adam), and when the training plateaus I restart with a smaller LR (LR /= 10). I have found that restarts help the optimization even if one keeps the LR fixed. I guess it has something to do with the internal states of the optimizer converging.
  • Single node, 4 GPUs. I am using the example from Gluon for data parallelization as described in the straight dope.
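
The manual procedure in the first bullet (fixed LR, then LR /= 10 on a plateau) could be automated roughly as below. This is only a sketch: the function name and the `patience` and `min_delta` thresholds are hypothetical, and it does not reset the optimizer state the way a restart does.

```python
def reduce_lr_on_plateau(lr, scores, patience=5, min_delta=1e-4, factor=0.1):
    """Return lr * factor if the best score did not improve in the last `patience` entries.

    scores: history of a validation metric where higher is better (e.g. GDC).
    """
    if len(scores) <= patience:
        return lr  # not enough history yet
    best_before = max(scores[:-patience])
    best_recent = max(scores[-patience:])
    if best_recent < best_before + min_delta:
        return lr * factor
    return lr


history = [0.50, 0.60, 0.70, 0.70, 0.70, 0.70]
print(reduce_lr_on_plateau(1e-3, history, patience=3))  # plateau: LR is divided by 10
```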

I’ll do some tests and get back, again, many thanks!!


#4

Although I haven't found what was causing the problem, I reinstalled the latest version of mxnet (mxnet_cu91-1.2.0b20180501) and everything works fine now.

Many thanks again @ThomasDelteil


#5

Glad you solved your issue @feevos ! Good luck with your project.