HybridBlock: how do I add more inputs to the hybrid_forward function?

Hi, I want to add more inputs to hybrid_forward…
Code:
from mxnet import nd
from mxnet.gluon import nn

class UNet(nn.HybridBlock):
    def __init__(self, in_channels, out_classes):
        super(UNet, self).__init__()
        with self.name_scope():
            self.inc = nn.HybridSequential(prefix='')
            self.down1 = nn.HybridSequential(prefix='')
            self.down2 = nn.HybridSequential(prefix='')
            self.down3 = nn.HybridSequential(prefix='')
            self.down4 = nn.HybridSequential(prefix='')
            self.up1 = nn.HybridSequential(prefix='')
            self.up2 = nn.HybridSequential(prefix='')
            self.up3 = nn.HybridSequential(prefix='')
            self.up4 = nn.HybridSequential(prefix='')
            self.out = nn.HybridSequential(prefix='')
            # input layer block
            self.inc.add(InConv(in_channels=in_channels, out_channels=8))
            # feature extraction, i.e. downsampling
            self.down1.add(DownPool(in_channels=8, out_channels=16))
            self.down2.add(DownPool(in_channels=16, out_channels=32))
            self.down3.add(DownPool(in_channels=32, out_channels=64))
            self.down4.add(DownPool(in_channels=64, out_channels=128))
            # the last downsampling stage outputs 128 channels
            # upsampling
            self.up1.add(UpConv(in_channels=128))
            self.up2.add(UpConv(in_channels=64))
            self.up3.add(UpConv(in_channels=32))
            self.up4.add(UpConv(in_channels=16))
            # the last upsampling stage outputs 8 channels
            # final output layer block
            self.out.add(OutConv(out_classes=out_classes))

    def hybrid_forward(self, F, x):
        # chain the blocks together in U-Net order
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.out(x)

# upsampling layer (transposed convolution)
class UpConv(nn.HybridBlock):
    def __init__(self, in_channels):
        super(UpConv, self).__init__()
        with self.name_scope():
            self.up_conv = nn.HybridSequential(prefix='')
            self.conv = nn.HybridSequential(prefix='')
            # transposed convolution (deconvolution) for upsampling
            self.up_conv.add(nn.Conv2DTranspose(channels=in_channels // 2, kernel_size=30, strides=2, padding=14))
            self.conv.add(DoubleConv(in_channels // 2, in_channels // 2))

    def hybrid_forward(self, F, x, y):
        out = []
        for block in self._children.values():
            # upsample / enlarge
            x = self.up_conv(x)
            out.append(block(x, y))
        return self.conv(nd.concat(*out, dim=1))

I am sorry to report that this code raises an error.

Could you tell me how to solve this problem? Thanks

Hi @DarkWings,

I see you are trying to build a U-Net, and this is why you need more inputs in hybrid_forward, is that correct?
In general it is straightforward to add additional arguments and use them; see this link and this one for a U-Net. However, for a U-Net in particular I don't see a benefit, since it is just a concat operation you are after. So the UNet below should be a good starting point for you (I haven't debugged the code, but it looks OK, at least as a starting point; you may want to add BatchNorm after the Conv2DTranspose layers too).
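To illustrate the general case first, here is a minimal sketch of a block whose hybrid_forward takes two inputs (the class name and layer choices here are just for illustration, not from your code):

import mxnet as mx
from mxnet.gluon import HybridBlock, nn

class TwoInputBlock(HybridBlock):
    def __init__(self, channels, **kwargs):
        super(TwoInputBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.conv = nn.Conv2D(channels, kernel_size=3, padding=1)

    # extra inputs are just extra positional arguments after F
    def hybrid_forward(self, F, x, y):
        # concatenate both inputs along the channel axis, then convolve
        return self.conv(F.concat(x, y, dim=1))

block = TwoInputBlock(16)
block.initialize()
block.hybridize()
a = mx.nd.random.uniform(shape=(1, 8, 32, 32))
b = mx.nd.random.uniform(shape=(1, 8, 32, 32))
out = block(a, b)  # call the block with both inputs

And here is the full UNet: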

import mxnet as mx
from mxnet import gluon 
from mxnet.gluon import HybridBlock # for purely imperative programming you could use Block instead


class UNetBlock(HybridBlock):
    def __init__(self, Nfilters, **kwargs):
        super(UNetBlock,self).__init__(**kwargs)


        with self.name_scope():
            self.act = gluon.nn.Activation('relu')


            self.conv1 = gluon.nn.Conv2D(Nfilters,kernel_size=3,padding=1,use_bias=False)
            self.bn1 = gluon.nn.BatchNorm(axis=1)
            self.conv2 = gluon.nn.Conv2D(Nfilters,kernel_size=3,padding=1,use_bias=False)
            self.bn2 = gluon.nn.BatchNorm(axis=1)


    def hybrid_forward(self,F,x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)

        return x



class UNet(HybridBlock):
    def __init__(self,  NClasses, Nfilters_init=32, **kwargs):
        super(UNet, self).__init__(**kwargs)
        
        
        with self.name_scope():

            # A single pool is enough, since it doesn't have parameters. 
            self.pool = gluon.nn.MaxPool2D(pool_size=2,strides=2)

            # 32
            self.block1 = UNetBlock(Nfilters_init)
            
            # 64
            self.block2 = UNetBlock(Nfilters_init*2)
            
            # 128 
            self.block3 = UNetBlock(Nfilters_init*2**2)
            
            # 256
            self.block4 = UNetBlock(Nfilters_init*2**3)
            
            # 512
            self.block5 = UNetBlock(Nfilters_init*2**4)

            
            
            # 256
            self.up6 = gluon.nn.Conv2DTranspose(Nfilters_init*2**3,kernel_size=(2,2), strides=(2,2),activation='relu') 
            self.block6 = UNetBlock(Nfilters_init*2**3)
            
            
            # 128 
            self.up7 = gluon.nn.Conv2DTranspose(Nfilters_init*2**2,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block7 = UNetBlock(Nfilters_init*2**2)
            
            # 64 
            self.up8 = gluon.nn.Conv2DTranspose(Nfilters_init*2**1,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block8 = UNetBlock(Nfilters_init*2)
            
            

            # 32 
            self.up9 = gluon.nn.Conv2DTranspose(Nfilters_init*2**0,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block9 = UNetBlock(Nfilters_init)
            
            self.convLast = gluon.nn.Conv2D(NClasses,kernel_size=(1,1),padding=0)
            
            
            
    def hybrid_forward(self,F,x):
        
        conv1 = self.block1(x)
        pool1 = self.pool(conv1)
            
        conv2 = self.block2(pool1)
        pool2 = self.pool(conv2)
 

        conv3 = self.block3(pool2)
        pool3 = self.pool(conv3)


        conv4 = self.block4(pool3)
        pool4 = self.pool(conv4)

        conv5 = self.block5(pool4)

       
        # UpSampling with transposed Convolution
        conv6 = self.up6(conv5)
        conv6 = F.concat(conv6,conv4)
        conv6 = self.block6(conv6)


        # UpSampling with transposed Convolution
        conv7 = self.up7(conv6)
        conv7 = F.concat(conv7,conv3)
        conv7 = self.block7(conv7)


        # UpSampling with transposed Convolution
        conv8 = self.up8(conv7)
        conv8 = F.concat(conv8,conv2)
        conv8 = self.block8(conv8)

        # UpSampling with transposed Convolution
        conv9 = self.up9(conv8)
        conv9 = F.concat(conv9,conv1)
        conv9 = self.block9(conv9)

        
        final_layer = self.convLast(conv9)
        final_layer = F.softmax(final_layer,axis=1)
        
        return final_layer
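
For reference, a quick smoke test of the above (the input shape is an assumption; any spatial size divisible by 16 works, since there are four 2x2 poolings):

net = UNet(NClasses=2)
net.initialize()
net.hybridize()

x = mx.nd.random.uniform(shape=(1, 3, 256, 256))  # NCHW input
y = net(x)
print(y.shape)  # (1, 2, 256, 256): per-pixel class probabilities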


Thank you for your help, my friend. That gave me a big hint, and I have now fixed it. Thanks!
