Any example showing how to train the FCN/PSP/DeepLab segmentation networks on custom data?

As the title says, are there any examples? Or do we need to dive into the source code and figure it out ourselves? Thanks

Hi,

there are a number of gluon-cv tutorials for the networks you mentioned.

See the link below:

https://gluon-cv.mxnet.io/build/examples_segmentation/index.html

If you want examples of preprocessing a custom dataset with a model from the model zoo, there’s a tutorial for using the pikachu dataset. Although it’s for object detection, the code path for cleaning and loading data into the model is similar for segmentation:

https://gluon-cv.mxnet.io/build/examples_detection/finetune_detection.html

And here’s a link to the API docs on the methods available for image preprocessing, transforms, and augmentation in mxnet gluon: https://mxnet.apache.org/api/python/gluon/data.html

And a tutorial for that is here: https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/data_augmentation.html

Hope these help

I have read those examples; the problem is that none of them shows you how to prepare the data needed by the segmentation examples. All I could do is study the code of “get_segmentation_dataset” to figure out how to adjust the code of train.py.

I will write a blog post to show other users how to train with a custom dataset after I figure out how to do it.

ps: I think examples of “train with custom data” are much more valuable than “train on a famous dataset” if you are not doing academic research.

Hi @stereomatchingkiss,

would you like to describe, step by step, where exactly you are having difficulty? Maybe I can help. For a semantic segmentation problem the pipeline, more or less, is something like this:

  1. Construct pairs of images and masks (masks can be in 1-hot encoding or not).
  2. Decide on a data augmentation methodology (problem dependent).
  3. Select an architecture (problem dependent).
  4. Select a loss function appropriate for semantic segmentation (problem dependent).
  5. Train your network (more or less standard code).

If you can be a bit more specific about where you are having difficulty and what type of application you want, perhaps I can help.

Different datasets have different peculiarities in pre-processing, so there is no one-size-fits-all methodology. E.g. when pre-processing large raster files (GeoTIFF etc.) there are two approaches: either read the file and slice it on the fly as you train, or slice it once and store the chips. Other datasets, e.g. biomedical imaging, have images/masks of different sizes, and significant pre-processing needs to take place to bring them into a fixed format (or you decide to re-size on the fly during training, as a transform; it really depends on what you want to do and what is more efficient for your application). Kaggle competitions are a good resource for data pre-processing (which is framework independent); you can find a lot of useful kernels there.
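For the “slice once and store the chips” approach, a minimal sketch (purely illustrative; the function name and the channels-first layout are my own assumptions, adapt to your data) could look like this:

import numpy as np

def slice_to_chips(img, mask, chip_size=256, stride=256):
    """Cut a large channels-first raster (C, H, W) and its 2D integer mask (H, W) into fixed-size chips."""
    chips = []
    _, H, W = img.shape
    for row in range(0, H - chip_size + 1, stride):
        for col in range(0, W - chip_size + 1, stride):
            chips.append((img[:, row:row + chip_size, col:col + chip_size],
                          mask[row:row + chip_size, col:col + chip_size]))
    return chips

You would then either save each pair to disk or keep the list in memory; the on-the-fly variant does the same slicing inside the dataset’s __getitem__ instead.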

All the best,
Foivos

Thanks, I am trying to use the code of the example “Train FCN on Pascal VOC Dataset” to train with a custom dataset. My questions are:

  1. How should I generate the mask?

According to the source code of pascal_voc/segmentation.py, I can create the mask as a color image or a gray image, as long as the pixel value for each category is unique.

If I use a color image as the mask, it could be
person == 255,0,255, bicycle == 0,255,255

If I use a gray image as the mask, it could be
person == 0
bicycle == 1

  2. How do you map the pixel value to a label in train.py?

For example, does person == 255,0,255 map to channel 0 and bicycle == 0,255,255 map to channel 1?
How can I find out which channel person or bicycle is mapped to?

  3. What should the input of MixSoftmaxCrossEntropyLoss be?

When aux is true.

The output of the model is an NDArray with a shape like (1, 21, 400, 400); the numbers from left to right are batch size, channel size (the same as the number of pixel categories), width and height.

The target is an NDArray with a shape like (1, 400, 400); the numbers from left to right are batch size, width and height.

Hi @stereomatchingkiss,

I am currently writing a blog post; I think it’ll be finished in a couple of days (ping @ThomasDelteil). Please add as many comments/questions as you may have, and I’ll try to include as many of them as I can.

For semantic segmentation problems you will usually have sets of images and annotated masks, i.e. for each image the mask will have exactly the same resolution as the image, and each pixel will contain some “class value”. What is important here is to distinguish between how the mask is given to you (before preprocessing) and what is expected by the loss function for training. For the loss function you always translate your masks into 1-hot encoding, assuming your classes are independent (unless you are doing exotic embeddings, this will always be the case).

Let’s start with an example from the ISPRS competition, where you are given several large raster files (input images) of dimension (5, 6k, 6k) and corresponding masks. This is an example of what the data look like (the segmentation mask shown is actually an inference result from one of the competitors):

[image: input tile and corresponding segmentation mask]
Now in this example, the organizers decided to annotate the segmentation mask using RGB colors, e.g. Building: (0, 0, 255), i.e. blue. Meaning that for each class they assign a particular color (e.g. blue is for buildings, green for trees etc.). In this competition there are a total of 6 classes. Once you get these data, you need to process them to bring them into an appropriate format so that you can calculate the loss function. So, for this particular competition, this is how you would create an integer mask. By integer mask I mean a large 2D raster file where each class corresponds to an integer.

import numpy as np

# These are our classes in RGB format; the corresponding integer values are shown commented out.
# ******************************************************************
Background = np.array([255,0,0]) #:{'name':'Background','cType':0},
ImSurf = np.array ([255,255,255])# :{'name':'ImSurf','cType':1},
Car = np.array([255,255,0]) # :{'name':'Car','cType':2},
Building = np.array([0,0,255]) #:{'name':'Building','cType':3},
LowVeg = np.array([0,255,255]) # :{'name':'LowVeg','cType':4},
Tree = np.array([0,255,0]) # :{'name':'Tree','cType':5}
# ******************************************************************

This is how you go fast, using numpy vectorization, from an RGB mask to an integer mask:

def rgb_to_2D_label(_label):
    """
    Here _label is the mask raster that corresponds to the input image,
    in channels-first layout, i.e. shape (3, H, W).
    """
    label_seg = np.zeros(_label.shape[1:], dtype=np.uint8)
    label_seg[np.all(_label.transpose([1,2,0]) == Background, axis=-1)] = 0
    label_seg[np.all(_label.transpose([1,2,0]) == ImSurf, axis=-1)] = 1
    label_seg[np.all(_label.transpose([1,2,0]) == Car, axis=-1)] = 2
    label_seg[np.all(_label.transpose([1,2,0]) == Building, axis=-1)] = 3
    label_seg[np.all(_label.transpose([1,2,0]) == LowVeg, axis=-1)] = 4
    label_seg[np.all(_label.transpose([1,2,0]) == Tree, axis=-1)] = 5

    return label_seg

Note: after you apply this function to the mask it becomes a 2D raster (6k,6k) with integer values (where class car is 2, class building is 3 and so on).
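For concreteness, here is a tiny usage sketch (purely illustrative, the 4x4 dummy mask is made up; it assumes the class-color arrays and rgb_to_2D_label above are defined):

import numpy as np

# dummy channels-first RGB mask (3, 4, 4): top half Building (blue), bottom half Tree (green)
dummy = np.zeros((3, 4, 4), dtype=np.uint8)
dummy[2, :2, :] = 255   # blue channel on  -> Building
dummy[1, 2:, :] = 255   # green channel on -> Tree

int_mask = rgb_to_2D_label(dummy)
print(int_mask.shape)       # (4, 4)
print(np.unique(int_mask))  # [3 5]  (Building and Tree)

Note that any pixel whose color matches none of the class colors silently stays 0 (Background), so it is worth checking that your color table is complete.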

If you wanted to translate this to 1hot encoding directly, this is how you’d go:

# translates image to 1-hot encoding
def rgb_to_1Hlabel(_label):
    # NClasses is the total number of classes (6 in this competition)
    teye = np.eye(NClasses, dtype=np.uint8)

    label_seg = np.zeros([*_label.shape[1:], NClasses], dtype=np.uint8)
    label_seg[np.all(_label.transpose([1,2,0]) == Background, axis=-1)] = teye[0]
    label_seg[np.all(_label.transpose([1,2,0]) == ImSurf, axis=-1)] = teye[1]
    label_seg[np.all(_label.transpose([1,2,0]) == Car, axis=-1)] = teye[2]
    label_seg[np.all(_label.transpose([1,2,0]) == Building, axis=-1)] = teye[3]
    label_seg[np.all(_label.transpose([1,2,0]) == LowVeg, axis=-1)] = teye[4]
    label_seg[np.all(_label.transpose([1,2,0]) == Tree, axis=-1)] = teye[5]

    return label_seg.transpose([2,0,1])

Once you do this, you need to chop it into small chips to feed into the network, e.g. small 256x256 pairs of (image, mask). Let’s look at this example (for simplicity, assume a single image, without a batch dimension):

[image: example training chip]
From left to right: input image, digital elevation map (height map), ground truth mask, prediction. We are interested in the 1st (input image that is fed to the network) and the 3rd image (ground truth mask). In this particular example we have 6 classes. Therefore the segmentation mask, once processed for input to the loss function (i.e. in 1-hot representation), will be an image of size (6, 256, 256). The first dimension is the channel axis. Each channel is a binary mask of where in the image the class you are after exists. So mask[0,:,:] is the binary mask of the first class, mask[1,:,:] is the binary mask of the 2nd class, and so on. You can always go back to the integer representation with integer_mask = np.argmax(_1hot_mask, axis=0).
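As a small sanity check of that round trip (again only a sketch, reusing the dummy 4x4 mask from the earlier snippet):

import numpy as np

NClasses = 6
one_hot = rgb_to_1Hlabel(dummy)            # (6, 4, 4): one binary channel per class
print(one_hot[3].astype(int))              # binary mask of the Building channel

int_mask = np.argmax(one_hot, axis=0)      # back to the integer representation
print(np.array_equal(int_mask, rgb_to_2D_label(dummy)))  # True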

I guess it makes more sense to look at this mask in a different way; let me visualize the various binary masks for you:

[images: the six per-class binary masks, one channel per class]
So to summarize, assuming you have a problem with NClasses = 6 classes, the last layer of your model must have dimensions (batch_size, 6, 256, 256). What follows is a typical UNet that applies softmax in the last layer, in order to bring the output into a format suitable for comparison with the 1-hot labels.


import mxnet as mx
from mxnet import gluon
from mxnet.gluon import HybridBlock  # HybridBlock lets you hybridize later; use Block if you prefer purely imperative programming


class UNetBlock (HybridBlock):
    def __init__(self, Nfilters, **kwargs):
        super(UNetBlock,self).__init__(**kwargs)


        with self.name_scope():
            self.act = gluon.nn.Activation('relu')


            self.conv1 = gluon.nn.Conv2D(Nfilters,kernel_size=3,padding=1,use_bias=False)
            self.bn1 = gluon.nn.BatchNorm(axis=1)
            self.conv2 = gluon.nn.Conv2D(Nfilters,kernel_size=3,padding=1,use_bias=False)
            self.bn2 = gluon.nn.BatchNorm(axis=1)


    def hybrid_forward(self,F,x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)

        return x



class UNet(HybridBlock):
    def __init__(self,  NClasses, Nfilters_init=32, **kwargs):
        super(UNet, self).__init__(**kwargs)


        with self.name_scope():

            # A single pool is enough, since it doesn't have parameters. 
            self.pool = gluon.nn.MaxPool2D(pool_size=2,strides=2)

            # 32
            self.block1 = UNetBlock(Nfilters_init)

            # 64
            self.block2 = UNetBlock(Nfilters_init*2)

            # 128 
            self.block3 = UNetBlock(Nfilters_init*2**2)

            # 256
            self.block4 = UNetBlock(Nfilters_init*2**3)

            # 512
            self.block5 = UNetBlock(Nfilters_init*2**4)



            # 256
            self.up6 = gluon.nn.Conv2DTranspose(Nfilters_init*2**3,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block6 = UNetBlock(Nfilters_init*2**3)


            # 128 
            self.up7 = gluon.nn.Conv2DTranspose(Nfilters_init*2**2,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block7 = UNetBlock(Nfilters_init*2**2)

            # 64 
            self.up8 = gluon.nn.Conv2DTranspose(Nfilters_init*2**1,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block8 = UNetBlock(Nfilters_init*2)



            # 32 
            self.up9 = gluon.nn.Conv2DTranspose(Nfilters_init*2**0,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block9 = UNetBlock(Nfilters_init)

            self.convLast = gluon.nn.Conv2D(NClasses,kernel_size=(1,1),padding=0)


    def hybrid_forward(self,F,x):

        conv1 = self.block1(x)
        pool1 = self.pool(conv1)

        conv2 = self.block2(pool1)
        pool2 = self.pool(conv2)

        conv3 = self.block3(pool2)
        pool3 = self.pool(conv3)

        conv4 = self.block4(pool3)
        pool4 = self.pool(conv4)

        conv5 = self.block5(pool4)

        # UpSampling with transposed Convolution
        conv6 = self.up6(conv5)
        conv6 = F.concat(conv6,conv4)
        conv6 = self.block6(conv6)

        # UpSampling with transposed Convolution
        conv7 = self.up7(conv6)
        conv7 = F.concat(conv7,conv3)
        conv7 = self.block7(conv7)

        # UpSampling with transposed Convolution
        conv8 = self.up8(conv7)
        conv8 = F.concat(conv8,conv2)
        conv8 = self.block8(conv8)

        # UpSampling with transposed Convolution
        conv9 = self.up9(conv8)
        conv9 = F.concat(conv9,conv1)
        conv9 = self.block9(conv9)


        final_layer = self.convLast(conv9)
        # @@@@@@@@@ ATTENTION ON AXIS OF SOFTMAX @@@@@@
        final_layer = F.softmax(final_layer,axis=1)
        # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
        return final_layer

The output of your model now is (batch_size, NClasses, 256, 256). Each channel now represents a probability heat map of the pixel belonging to the specific class. This is another example of what this looks like (the third image is a heat map, i.e. the softmax output of the last layer of the network):

[images: example panels, with the probability heat map as the third one]

If you threshold the heat map, i.e. accept every pixel > 0.5 (or some other threshold) as belonging to that class, you end up with a hard class label. Another thing to note, as a sanity check: if you sum along the channel axis, every pixel will give you 1, i.e. the sum of “probabilities” of belonging to some class.
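In code, the hard-label step and the sanity check could look like this (an illustrative sketch; probs stands in for the (NClasses, H, W) softmax output of a single image):

import numpy as np

probs = np.random.rand(6, 256, 256)
probs = probs / probs.sum(axis=0, keepdims=True)  # fake softmax output: channels sum to 1

hard_labels = np.argmax(probs, axis=0)            # integer class per pixel
building_mask = probs[3] > 0.5                    # or threshold a single class channel

print(np.allclose(probs.sum(axis=0), 1.0))        # sanity check: True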

I think this also covers your 2nd question. For the third: this network uses two loss functions at two different depths of the network, and this is what aux=True means. So in the example the network outputs two layers, and each goes into the loss. I wouldn’t worry about this at this stage. Try the network I’ve given you here as a starter, which is much simpler.

As for the loss function, a trivial test case would be to flatten the (batch, NClasses, 256, 256) predictions to vectors of shape (batch, …), do the same for the ground truth labels, and then use binary cross entropy (after all, they are all 1’s and 0’s) as a loss. In practice this doesn’t work well, because the classes are usually imbalanced.
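A toy version of that baseline, just to make the shapes concrete (the shapes and random tensors are made up; from_sigmoid=True because the predictions are already probabilities):

import mxnet as mx
from mxnet import nd, gluon

bce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)

pred = nd.random.uniform(shape=(4, 6, 256, 256))   # pretend softmax output
label = nd.one_hot(nd.random.randint(0, 6, shape=(4, 256, 256)).astype('float32'), 6)
label = label.transpose((0, 3, 1, 2))              # 1-hot, channels first

loss = bce(pred.reshape(4, -1), label.reshape(4, -1))
print(loss.shape)  # (4,): one value per sample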

Edit: this loss function will get you started. It is the Jaccard loss; not the best, as it doesn’t treat class imbalance, but it is simple enough to make sense of.

from mxnet.gluon.loss import Loss



class jaccard(Loss):
    """
    Jaccard loss coefficient. Adapted from tensorlayer:
    https://github.com/zsdonghao/tensorlayer/blob/master/tensorlayer/cost.py
    INPUT:
        tensors of size (Nbatch, Nclasses, W, H)
    OUTPUT:
        The average value over classes, per sample (a vector of length Nbatch).
    """

    def __init__(self, _smooth=1.0e-3, _axis=[2,3], _weight=None, _batch_axis=0, **kwards):
        Loss.__init__(self, weight=_weight, batch_axis=_batch_axis, **kwards)

        with self.name_scope():
            self.smooth = _smooth
            self.axis = _axis

    def hybrid_forward(self, F, _prediction, _label):

        # intersection and "norms" of label and prediction, summed over the spatial axes
        itrs = F.sum(F.broadcast_mul(_label, _prediction), axis=self.axis)
        l = F.sum(F.broadcast_mul(_label, _label), axis=self.axis)
        r = F.sum(F.broadcast_mul(_prediction, _prediction), axis=self.axis)

        IoU = (2.0*itrs + self.smooth)/(l + r + self.smooth)

        # Note: this returns the overlap coefficient (higher is better);
        # to use it as a loss to minimize, take 1 - coefficient.
        return F.mean(IoU, axis=1)
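Since the class returns the overlap coefficient (higher is better), a minimal way to use it for training is to minimize one minus it; a sketch with made-up tensors:

from mxnet import nd

jacc = jaccard()

preds = nd.softmax(nd.random.uniform(shape=(2, 6, 64, 64)), axis=1)    # pretend network output
labels = nd.one_hot(nd.zeros((2, 64, 64)), 6).transpose((0, 3, 1, 2))  # dummy 1-hot ground truth

loss = 1.0 - jacc(preds, labels)  # minimize 1 - coefficient
print(loss)                       # one value per sample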

Let me get back to the blog post now, where I’ll write more about loss functions etc. Please let me know if you have more questions. I hope the above helps.

All the best,
Foivos


Thanks, I never imagined you would write down so many details; I know this must have taken many hours. I have used PyTorch to implement LinkNet before and written a blog post about it too. The way I created the mask was very slow because I did not use vectorization, so thanks for the vectorized solution.

The other questions are

  1. What is aux doing? I noticed that the FCN model outputs a tuple with two NDArrays, each of the same size, if I set aux to false; how should I deal with them? If I set aux to true, it only outputs one NDArray.

  2. I know how to generate the mask and write the loss function from scratch, but I would like to know how I can reuse the train.py provided by the gluon-cv examples. How should I replace the get_segmentation_dataset function in train.py, and how does train.py map each pixel category to a channel? The examples provided by gluon-cv are valuable, but they are hard to reuse; I spent a lot of hours figuring out how to train YOLO v3 on custom data.

  3. Do I need to keep aux at the same value for training and prediction?

ps: I haven’t studied the PSPNet paper yet; maybe I will know what aux is doing after I read it, but it would be better if you could give us a higher-level explanation of aux.

Edit: I do not see segmentation.py map color values like [244,244,0] to a single integer value like 0, 1, 2; I saw that it reads the mask as a gray image directly.

Edit: I do not see train.py handle the problem of imbalanced pixels, but I saw the aux_weight option; what is that for?


Hi @stereomatchingkiss,

the PSPNet outputs two layers that do the same job, i.e. both try to classify the same target. This is evident from the definition of MixSoftmaxCrossEntropyLoss2D, where the same label is used for both prediction outputs (see the hybrid_forward call). One of the reasons for using multiple loss functions at different levels of the network has to do with vanishing gradients, especially when the goal is the same in both tasks (see An Overview of Multi-Task Learning for Deep Learning for an overview of multi-task learning).

If you look at this slide, you’ll notice that each vector (which represents the gradient flow) becomes dashed at some point. This represents the fact that if the network is too deep, the gradients diminish. It is like calculating something_small^n: the larger the n, the smaller the result, assuming something_small < 1. According to the chain rule, the more layers you have in the network, the larger the number of products in the partial derivative. Using complementary losses at different depths “keeps the gradients flowing”. Caution needs to be taken in how you combine the different loss functions (hence the term self.aux_weight):

def _aux_mixup_forward(self, F, pred1, pred2, label1, label2, lam):
        """Compute loss including auxiliary output"""
        loss1 = self._mixup_forward(F, pred1, label1, label2, lam)
        loss2 = self._mixup_forward(F, pred2, label1, label2, lam)
        return loss1 + self.aux_weight * loss2

  1. get_segmentation_dataset: if you look at the definition in the source code, you will see that this function only returns a predefined dataset. So you just need to create your own custom dataset for your (images, masks) pairs by subclassing the gluon.data.Dataset class; you can find an example here, and an analog using hdf5 files here. A minimal sketch of such a dataset follows after this list. Then, as long as the output of the model has the correct number of channels (i.e. the number of classes), the algorithm will learn on its own the mapping of each channel to a specific class. That is, it will learn for example that the car mask is always in the first channel (if that is the case in the 1-hot ground truth).

  2. I understand that it is difficult to port the examples; I’ve walked the same path many times. My best advice, from my experience, is to try to write everything on your own, using the examples as a road map. It is time consuming, I know, but I am afraid I cannot help more on this, as it would require a substantial portion of my time (I have duties at work).
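Here is the minimal custom dataset sketch I mentioned above (only a sketch: the folder layout, file naming, normalization, and the optional transform hook are assumptions you would adapt to your own data):

import os
import mxnet as mx
from mxnet import gluon


class CustomSegDataset(gluon.data.Dataset):
    """Loads (image, mask) pairs from two folders with matching file names."""

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.names = sorted(os.listdir(img_dir))
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        img = mx.image.imread(os.path.join(self.img_dir, name))            # (H, W, 3) uint8
        mask = mx.image.imread(os.path.join(self.mask_dir, name), flag=0)  # (H, W, 1) grayscale integer mask
        img = img.astype('float32').transpose((2, 0, 1)) / 255.0           # to (3, H, W)
        mask = mask.squeeze().astype('float32')                            # to (H, W) integer labels
        if self.transform is not None:
            img, mask = self.transform(img, mask)
        return img, mask

You would then wrap it in a gluon.data.DataLoader as usual, and apply the 1-hot conversion either inside __getitem__ or just before the loss.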

The loss function is not used during inference, so I am not sure I understand the third question.

Hope the above helps,
All the best


Hi, on the first point, do you have a link to the segmentation.py source code?

On aux_weight: the source code of train.py has, in lines 145–147:
# create criterion
criterion = MixSoftmaxCrossEntropyLoss(args.aux, aux_weight=args.aux_weight)
self.criterion = DataParallelCriterion(criterion, args.ctx, args.syncbn)

so the variable args.aux_weight has the class weighting (I don’t know where it gets it from).

edit: this is a mistake. I looked again at the source code; the variable args.aux_weight is a constant, 0.5, and it is not a weight for balancing the different classes. It is used for combining the different losses, giving higher priority (i.e. a larger gradient contribution) to the first loss. I cannot comment on whether this is good or bad.

With regard to class imbalance, I found that the weighting scheme from Crum et al. 2006 works well, especially the inverse volume. It is described in detail in Sudre et al. 2017, and below is an mxnet implementation of it. Caution: the weighting is good, but the loss function is not the best; I have a paper on loss functions for semantic segmentation under internal review, and once it is out I will upload it to arXiv and link it here:

import numpy as np
from mxnet.gluon.loss import Loss

class GDCoeff(Loss):
    """
    Generalized Dice coefficient (Sudre et.al. 2017) for the case of multiclass problems. 
    There are variations in the definition across the literature. Some authors calculate a single IoU from the whole batch by also summing over the batch axis.
    I prefer to take the mean over batches and return (eventually) the average GDCoeff over all classes, per image. But see the source code for details
    (I may have forgotten to update the docs here).
    """
    def __init__(self, _smooth=1.0e-3, _axis=[2,3], _weight = None, _batch_axis= 0, **kwards):
        Loss.__init__(self,weight=_weight, batch_axis = _batch_axis, **kwards)

        self.axis = _axis
        self.smooth = _smooth

    def hybrid_forward(self, F, _preds, _label):

        # Evaluate the mean volume of each class per batch
        Vli = F.mean(F.sum(_label, axis=self.axis), axis=0)
        wli = 1.0/Vli**2  # weighting scheme

        # ---------------------This line is taken from niftyNet package -------------- 
        # ref: https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py, lines:170 -- 172  
        # new_weights = tf.where(tf.is_inf(weights), tf.zeros_like(weights), weights)
        # weights = tf.where(tf.is_inf(weights), tf.ones_like(weights) * tf.reduce_max(new_weights), weights)
        # --------------------------------------------------------------------

        # ***********************************************************************************************
        # First turn inf elements to zero, then replace that with the maximum weight value  
        # This produces higher - in general - gradient values, so it helps in training! 
        new_weights = F.where(wli == np.inf, F.zeros_like(wli), wli)
        wli = F.where(wli == np.inf, F.broadcast_mul(F.ones_like(wli), F.max(new_weights)), wli)
        # ************************************************************************************************

        rl_x_pl = F.sum( F.broadcast_mul(_label , _preds), axis=self.axis)
        rl_p_pl = F.sum(_preds +_label,axis=self.axis)

        gdc = 2.0 * F.sum( F.broadcast_mul(wli , rl_x_pl),axis=1) / F.sum( F.broadcast_mul(wli,(rl_p_pl)),axis=1)

        return gdc # This returns the gdc for EACH data point, i.e. a vector of values equal to the batch size 
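As with the Jaccard class above, this returns a coefficient (higher is better), so for training you would typically minimize one minus it; an illustrative snippet with made-up softmax predictions and 1-hot labels:

from mxnet import nd

gd = GDCoeff()
preds = nd.softmax(nd.random.uniform(shape=(2, 6, 64, 64)), axis=1)
labels_1hot = nd.one_hot(nd.zeros((2, 64, 64)), 6).transpose((0, 3, 1, 2))
loss = 1.0 - gd(preds, labels_1hot)  # minimize 1 - generalized Dice coefficient
print(loss)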

All the best,
Foivos


Thanks for the explanation of aux; I will study it more for the details. At least now I know that aux uses more than one loss function, and that its purpose is to alleviate the vanishing gradient phenomenon.

Thanks, I am trying to create a class to read custom data; I will write a blog post and share the code if it works.

Thanks for your advice and your precious time; glad to know I am not the only one who found the examples not intuitive to reuse.

In this example, it calls the API

output = model.demo(img)

Under the hood, it is doing this (SegBaseModel in gluoncv/model_zoo/segbase.py):

def demo(self, x):
        h, w = x.shape[2:]
        self._up_kwargs['height'] = h
        self._up_kwargs['width'] = w
        pred = self.forward(x)
        if self.aux:
            pred = pred[0]
        return pred

That self.aux looks like it is related to the aux used during training, or are they not related? How should I use the output if it returns two NDArrays?

#like this?
pred = (pred[0] + pred[1] ) / 2.0

Yes, here

Thanks, I can try to combine the weighting logic into the MixSoftmaxCrossEntropyLoss loss function.

Hi,

looking at the source code you quoted, the output of the forward pass is always two layers:

def base_forward(self, x):
        """forwarding pre-trained network"""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        c3 = self.layer3(x)
        c4 = self.layer4(c3)
        return c3, c4

therefore, it seems the model always produces a tuple as an output. Then the part:

pred = self.forward(x)
if self.aux:
    pred = pred[0]

is equivalent to:

pred1, pred2 = self.forward(x)
if self.aux:
    return pred1
return pred1, pred2 

i.e. self.aux is a boolean that is used for convenience; I don’t see another special purpose. So, how do you use it? Since usually the deeper the better, use pred2 for the final prediction. In practice you can monitor the outputs of both predictions and keep the one that gives you the best result. Edit: I suggest going into the paper, reading how they use it, and then deciding.

So, I went and looked into the segmentation.py file you quoted; line 30 shows that it uses the VOCSegmentation dataset. The only thing you need to do to understand how this works is to load it. Let’s see what we get:

In [1]: import gluoncv

In [2]: trainset = gluoncv.data.VOCSegmentation(split='val')

In [3]: timg, tlabel = trainset[10]

In [4]: print (timg.shape, tlabel.shape)
(480, 480, 3) (480, 480)

In [5]: timg = timg.asnumpy()

In [6]: tlabel = tlabel.asnumpy()

so tlabel is a 2D “image” (a matrix), presumably with integer entries. If we look at the unique values:

In [9]: np.unique(tlabel)
Out[9]: array([-1.,  0., 18.], dtype=float32)

If you visualize it, this is what you get:

In [14]: imshow(tlabel,cmap=cm.Dark2)
Out[14]: <matplotlib.image.AxesImage at 0x7f5a4c7087b8>

In [15]: colorbar()
Out[15]: <matplotlib.colorbar.Colorbar at 0x7f5a4c0cb940>

[image: tlabel displayed with a discrete colormap and colorbar]
So the ground truth label has an integer representation, and you need to manually translate this to a 1-hot representation (mxnet has a convenience function to do so as well).
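For example, a quick sketch of that last step (only illustrative; note the -1 “void” pixels have to be handled somehow, here they are crudely sent to class 0, whereas in practice you would mask them out):

from mxnet import nd

label = nd.array(tlabel)                                   # (480, 480), values -1, 0, 18
label = nd.where(label < 0, nd.zeros_like(label), label)   # crude: send void (-1) pixels to class 0
one_hot = nd.one_hot(label, 21)                            # (480, 480, 21) for the 21 VOC classes
one_hot = one_hot.transpose((2, 0, 1))                     # channels first: (21, 480, 480)
print(one_hot.shape)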

All the best,
Foivos


It’s never only two days, is it? Tutorial Part I

If you want to get a feeling for loss functions for semantic segmentation, check our latest paper on the topic. I will also write a blog post about the paper (with things that didn’t work) and share code.

All the best,
Foivos


Thanks, will study them :)
