End-to-End Gluon Model with Skip Connections


#1

I’m trying to make a custom end-to-end model implemented with Gluon and was wondering if it’s possible to use Skip Connections with a pure HybridSequential model.

I’ve created a Concat gluon wrapper HybridBlock for this purpose, but I don’t know how to use it within a HybridSequential properly.

My initial idea was to create a HybridSequential for both the encoder and decoder and the encoder sequential would connect with the decoder sequential model with concat.

from mxnet.gluon.nn import HybridBlock, HybridSequential

class Decoder(HybridBlock):
    def __init__(self):
        super(Decoder, self).__init__()
        self.model = HybridSequential()
        self.model.add(
            # Convolution blocks
        )

    def hybrid_forward(self, F, x):
        return self.model(x)

class Encoder(HybridBlock):
    def __init__(self, decoder_model):
        super(Encoder, self).__init__()
        self.model = HybridSequential()
        self.model.add(
            # Convolution Block
            Concat(decoder_model[-1]),
            # etc...
        )

    def hybrid_forward(self, F, x):
        return self.model(x)

class Model(HybridBlock):
    def __init__(self):
        super(Model, self).__init__()
        self.model = HybridSequential()
        decoder = Decoder()
        self.model.add(
            decoder,
            Encoder(decoder.model)
        )

    def hybrid_forward(self, F, x):
        return self.model(x)

But when I thought about how this would work during the forward pass, it seemed like the Encoder would just forward through the Decoder's blocks a second time instead of reusing their outputs.
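
This concern can be checked without MXNet. The sketch below is plain Python, not Gluon: CountingLayer and run_sequential are hypothetical stand-ins for a block and for a sequential container's forward pass. Under those assumptions, it suggests that placing an already-used block inside a second container simply invokes it again on whatever input arrives there, rather than reusing the output it produced earlier.

```python
class CountingLayer:
    """Hypothetical stand-in for a layer: adds a constant and counts its calls."""

    def __init__(self, increment):
        self.increment = increment
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return x + self.increment


def run_sequential(layers, x):
    """Feed x through the layers one after the other, as a sequential container does."""
    for layer in layers:
        x = layer(x)
    return x


shared = CountingLayer(1)
first_stage = [CountingLayer(10), shared]
second_stage = [CountingLayer(100), shared]  # reuses the same block object

hidden = run_sequential(first_stage, 0)         # 0 -> 10 -> 11
output = run_sequential(second_stage, hidden)   # 11 -> 111 -> 112

# The shared block ran twice; its first output (11) was not reused.
print(hidden, output, shared.calls)
```

To reuse an earlier output, it has to be kept in a variable and passed along explicitly.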

Any tips on how to move forward would be greatly appreciated.


#2

Here is a simplified example of a model that uses a skip connection:

import mxnet as mx

class Example(mx.gluon.nn.HybridBlock):
    def __init__(self, filters):
        super(Example, self).__init__()

        self.F1, self.F3 = filters
        with self.name_scope():
            self.part1 = mx.gluon.nn.HybridSequential()
            with self.part1.name_scope():
                # 1x1 convolution -> batch norm -> ReLU
                self.part1.add(mx.gluon.nn.Conv2D(channels=self.F1, kernel_size=(1, 1), strides=(1, 1)))
                self.part1.add(mx.gluon.nn.BatchNorm(axis=1))
                self.part1.add(mx.gluon.nn.Activation(activation='relu'))

                # 1x1 convolution -> batch norm (the ReLU is applied after the addition)
                self.part1.add(mx.gluon.nn.Conv2D(channels=self.F3, kernel_size=(1, 1), strides=(1, 1)))
                self.part1.add(mx.gluon.nn.BatchNorm(axis=1))

    def hybrid_forward(self, F, X):
        l1 = self.part1(X)
        # Skip connection: add the input to the output of the convolution branch
        return (l1 + X).relu()

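In the forward pass, we feed the data X through some convolution layers, l1 = self.part1(X), then add the result to the input and apply a ReLU: (l1 + X).relu(). For the addition to work, l1 must have the same shape as X, so the last filter count (F3) has to match the number of channels of X. The network can be visualized as follows: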

conv_block  = Example(filters=[64, 256])
symbol_data = mx.sym.var('data')
tmp         = conv_block(symbol_data)

mx.viz.plot_network(tmp)


#3

I’m aware of how to make a single model using skip connections in hybrid_forward. But the custom model I’m using is constructed with modularity in mind, and implemented solely with HybridBlocks/Sequentials. I was wondering if it’s possible to access the output of each block in a HybridSequential after the forward pass.

In a sense the model has three components: A, B, and C. Each component is essentially a HybridBlock containing a HybridSequential as its model. All I want hybrid_forward to do is pass the input through the HybridSequential, and nothing else if possible.

A --> Decoder (Downsample)
B --> Skip Connections with convolutions
C --> Encoder (Upsample)

The architecture is a mix between a U-Net and a Grid Net.


#4

Hi @zahidislm, you can access the elements of a HybridSequential as you would in a list, but you need to store the outputs in intermediate variables, since you will need them for concatenation with the layers of the decoder. You can store the outputs of the convolutions (or whatever the building blocks of your UNet are) in a list inside the hybrid_forward function and access them by index. The example below demonstrates how to access the layers of a HybridSequential by index. It is not a complete implementation with for loops, but it should give you the idea. Please let me know if this is what you are after, or if you need something else.

from mxnet import gluon

class UNet(gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super(UNet, self).__init__(**kwargs)

        with self.name_scope():
            # Encoder: 3 convolution layers. Pooling is applied explicitly in
            # hybrid_forward, so that self.encoder[i] always indexes a convolution.
            self.encoder = gluon.nn.HybridSequential()
            for _ in range(3):
                self.encoder.add(gluon.nn.Conv2D(...))

            # Assuming you are using a UNet, the decoder has one layer fewer
            self.decoder = gluon.nn.HybridSequential()
            for _ in range(3 - 1):
                self.decoder.add(gluon.nn.Conv2D(...))

    # Let's assume x.shape -> (batch_size, NChannels, 256, 256)
    def hybrid_forward(self, F, x):
        out1_256 = self.encoder[0](x)
        out2_128 = F.Pooling(out1_256, kernel=(2, 2), stride=(2, 2))  # half the size
        out3_128 = self.encoder[1](out2_128)
        out4_64 = F.Pooling(out3_128, kernel=(2, 2), stride=(2, 2))
        out5_64 = self.encoder[2](out4_64)  # This is the middle layer

        # Upsample the middle layer, with Conv2DTranspose or interpolation
        out5_128 = F.UpSample(out5_64, ...)
        out6_128 = F.concat(out5_128, out3_128, dim=1)
        out6_128 = self.decoder[0](out6_128)

        out6_256 = F.UpSample(out6_128, ...)
        out7_256 = F.concat(out6_256, out1_256, dim=1)
        out7_256 = self.decoder[1](out7_256)

        logits = F.softmax(out7_256, axis=1)  # assuming you are doing semantic segmentation
        return logits
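
The indexing pattern above generalizes to loops. What follows is only a minimal sketch in plain Python rather than Gluon: it tracks (channels, height, width) shapes instead of real tensors, and conv, pool, upsample, concat and unet_shapes are hypothetical helpers with made-up channel counts. It illustrates the bookkeeping: store each encoder output in a list, then consume the list from the end while decoding.

```python
def conv(out_channels):
    """Stand-in for a shape-preserving convolution (e.g. 3x3 with padding 1)."""
    def apply(shape):
        _, height, width = shape
        return (out_channels, height, width)
    return apply


def pool(shape):
    """Stand-in for 2x2 pooling with stride 2: halves the spatial size."""
    channels, height, width = shape
    return (channels, height // 2, width // 2)


def upsample(shape):
    """Stand-in for 2x upsampling: doubles the spatial size."""
    channels, height, width = shape
    return (channels, height * 2, width * 2)


def concat(a, b):
    """Channel-wise concatenation: spatial sizes must match, channels add up."""
    assert a[1:] == b[1:], "spatial sizes must match"
    return (a[0] + b[0], a[1], a[2])


def unet_shapes(x, encoder, decoder):
    skips = []
    for layer in encoder[:-1]:
        x = layer(x)
        skips.append(x)      # remember the output for the symmetric decoder layer
        x = pool(x)
    x = encoder[-1](x)       # middle layer, no skip connection
    for layer in decoder:
        x = upsample(x)
        x = concat(x, skips.pop())  # last stored output is used first
        x = layer(x)
    return x


encoder = [conv(32), conv(64), conv(128)]
decoder = [conv(64), conv(32)]
result = unet_shapes((3, 256, 256), encoder, decoder)
print(result)
```

Popping from the end of the list pairs the first encoder layer with the last decoder layer, the second with the second-to-last, and so on.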
  

Hope this helps.

PS: I think that by skip connections you don't mean the ResNet summation, but the skip connections (concatenations) used in the UNet.
PPS: I define the encoder as the part of the network that leads from higher order features (higher spatial dimensions) to lower order features (lower spatial dimensions). So in my pseudocode, [encoder, decoder] corresponds to what you define as [decoder, encoder].
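
The distinction drawn in the PS can be made concrete with shape arithmetic. The two helpers below are hypothetical and operate on plain (batch, channels, height, width) tuples rather than on MXNet arrays. As a sketch, they show that a ResNet-style sum needs identical shapes and keeps the channel count, while a UNet-style concatenation only needs matching batch and spatial sizes and adds the channel counts.

```python
def residual_add_shape(a, b):
    """Shape of a + b for a ResNet-style skip: both shapes must be identical."""
    if a != b:
        raise ValueError("element-wise sum needs identical shapes: %s vs %s" % (a, b))
    return a


def channel_concat_shape(a, b):
    """Shape of a channel-axis concatenation for a UNet-style skip."""
    if a[0] != b[0] or a[2:] != b[2:]:
        raise ValueError("batch and spatial sizes must match: %s vs %s" % (a, b))
    return (a[0], a[1] + b[1]) + tuple(a[2:])


summed = residual_add_shape((8, 256, 32, 32), (8, 256, 32, 32))
joined = channel_concat_shape((8, 128, 64, 64), (8, 64, 64, 64))
print(summed, joined)
```

The growing channel count after concatenation is why the layer that follows a UNet skip has to accept the sum of both channel counts as its input.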


#5

Thanks @feevos for the help so far, I really appreciate it. Essentially, I wanted to see if it was possible to define everything during init as a HybridSequential.

I was wondering if it was possible to somehow use concat within a HybridSequential, instead of doing it in hybrid_forward. I implemented a HybridBlock version of Concat (see below), but I don’t know how to pass multiple args to a HybridSequential during hybrid_forward without resorting to messy hacks like overriding the base Block class.

class Concat(HybridBlock):
    def __init__(self, dim=1, **kwargs):
        super(Concat, self).__init__(**kwargs)
        self._kwargs = {'dim': dim}

    def hybrid_forward(self, F, *args):
        return F.concat(*args, **self._kwargs)

#6

Hi @zahidislm, I don’t think what you are after is possible, but let’s wait and see what more experienced users have to say.

I think the critical point that prevents the operation you are after is this: inside the def __init__(self,...) function you only declare what type of layers you are going to use. You are not performing any operation with them there, i.e. they are not used until forward is called. In addition, HybridSequential is sequential, which means you can either use it as a list container (as explained above), or forward through all layers sequentially, one after the other. The latter is not what happens in a UNet, where you concatenate layers that are symmetric with respect to the central layer (first with last, second with second-to-last, and so on).
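
The two ways of using a sequential container can be sketched in plain Python. TinySequential below is a hypothetical stand-in written for illustration, not Gluon's implementation; the layers are ordinary functions on numbers. It shows that the chained call threads a single value through every layer, while indexing lets the caller keep an intermediate result and combine it later.

```python
class TinySequential:
    """Hypothetical minimal sequential container (not Gluon's implementation)."""

    def __init__(self, *layers):
        self._layers = list(layers)

    def __getitem__(self, index):
        return self._layers[index]

    def __len__(self):
        return len(self._layers)

    def __call__(self, x):
        # Chained use: each layer sees only the previous layer's output.
        for layer in self._layers:
            x = layer(x)
        return x


def add_one(x):
    return x + 1


def double(x):
    return x * 2


def square(x):
    return x * x


net = TinySequential(add_one, double, square)

# 1) Forward through all layers, one after the other.
chained = net(3)                    # ((3 + 1) * 2) ** 2

# 2) Use the container as a list and combine an early output with a late one.
first = net[0](3)
second = net[1](first)
combined = net[2](second) + first   # the early output is reused here

print(chained, combined)
```

In the chained call there is no point at which the output of the first layer is still available to the third, which is why the skip has to be written out by the caller.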

The layers you declare inside __init__ are used inside the hybrid_forward function. What I mean by this is that the concatenation you are after acts on the outputs of the layers declared inside the init function. So, if you don’t explicitly evaluate those outputs, concatenation doesn’t make much sense. To put it another way:

class SomeNet(HybridBlock):
    def __init__(....):
        super(SomeNet, self).__init__()

        with self.name_scope():
            # This is only declared here. It consists of just the weights and bias
            # as parameters; it doesn't even have the input dimensions that you
            # will apply it to.
            self.net = Conv2D(...)

    def hybrid_forward(self, F, x):
        # Now here, x has true dimensions, it's a true tensor. The same holds for out1 and out2.
        out1 = self.net(x)
        out2 = self.net(out1)

        # Now it makes sense to concatenate layers that correspond to full images/filters.
        out = F.concat(out1, out2, dim=1)
        ....  # do more
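
The point made in the comments above, that a declared layer knows its hyperparameters but not yet its input, can be sketched in plain Python. LazyDense is a hypothetical toy class, not a Gluon layer: it creates its weights only when it first sees an input, so there is nothing to concatenate until the layer has actually been called.

```python
class LazyDense:
    """Hypothetical toy layer: weights are created on the first call."""

    def __init__(self, units):
        self.units = units
        self.weight = None  # input size is unknown at declaration time

    def __call__(self, x):
        if self.weight is None:
            # The input size becomes known only now, from the data itself.
            self.weight = [[1.0] * len(x) for _ in range(self.units)]
        return [sum(w * v for w, v in zip(row, x)) for row in self.weight]


layer = LazyDense(units=2)
declared_weight = layer.weight            # still None: nothing has been computed

out = layer([1.0, 2.0, 3.0])              # first call fixes the input size
weight_shape = (len(layer.weight), len(layer.weight[0]))
print(declared_weight, out, weight_shape)
```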

Again, take what I write with a grain of salt, and wait for more experienced users to comment.

Regards,
Foivos