Gluon models with branches

FYI, with your help I was able to quickly put together a shallow implementation of UNet.

class DoubleConv(nn.HybridBlock):
    """Two stacked Conv2D -> BatchNorm stages.

    Every keyword argument is forwarded unchanged to *both* Conv2D layers,
    so the two convolutions share one configuration (channels, kernel size,
    strides, padding, activation, ...). This block adds no activation of its
    own: any nonlinearity must come from an ``activation`` keyword that is
    handed through to Conv2D.
    """

    def __init__(self, **kwargs):
        super(DoubleConv, self).__init__()
        # Keep these attribute names and this creation order: Gluon derives
        # parameter names/structure from them, so changing either would
        # likely break loading of previously saved parameters.
        self.conv_1 = Conv2D(**kwargs)
        # axis=1 is the channel axis under Conv2D's default NCHW layout.
        self.batchnorm_1 = BatchNorm(axis=1)
        self.conv_2 = Conv2D(**kwargs)
        self.batchnorm_2 = BatchNorm(axis=1)

    def hybrid_forward(self, F, x):
        # Apply each (convolution, normalization) pair in sequence.
        stages = (
            (self.conv_1, self.batchnorm_1),
            (self.conv_2, self.batchnorm_2),
        )
        out = x
        for conv, norm in stages:
            out = norm(conv(out))
        return out

class Unet(nn.HybridBlock):
    """Shallow two-level U-Net.

    Encoder: DoubleConv -> 2x2 max-pool (twice), then a bottleneck DoubleConv.
    Decoder (twice): 2x2 stride-2 transposed-conv upsampling, concatenation
    with the matching encoder feature map (skip connection) along the channel
    axis, then a DoubleConv. A final 3x3 Conv2D with tanh produces the output.

    Parameters
    ----------
    base_channels : int, default 16
        Width of the first encoder level; deeper levels use 2x and 4x this.
    out_channels : int, default 1
        Number of channels produced by the final tanh convolution.

    Notes
    -----
    Assumes NCHW input (the Conv2D default layout) whose height and width are
    divisible by 4. Otherwise the pooled-then-upsampled maps do not match the
    skip-connection maps and ``concat`` fails. TODO confirm against callers.
    """

    def __init__(self, base_channels=16, out_channels=1):
        super(Unet, self).__init__()
        c1, c2, c3 = base_channels, 2 * base_channels, 4 * base_channels
        # Shared settings for every 3x3 "same"-padded convolution inside the
        # DoubleConv stages. DoubleConv adds no nonlinearity itself, so ReLU
        # is requested here (it runs inside each Conv2D, before BatchNorm).
        # Without it every conv/BatchNorm stack is just an affine map.
        conv3x3 = dict(kernel_size=3, strides=1, padding=1, activation='relu')
        # Creation order matches the original so auto-generated parameter
        # prefixes are unchanged.
        self.mp_1 = MaxPool2D(pool_size=(2, 2), strides=2)
        self.mp_2 = MaxPool2D(pool_size=(2, 2), strides=2)
        self.cv_1 = DoubleConv(channels=c1, **conv3x3)
        self.cv_2 = DoubleConv(channels=c2, **conv3x3)
        self.cv_3 = DoubleConv(channels=c3, **conv3x3)
        self.cv_4 = DoubleConv(channels=c2, **conv3x3)
        self.cv_5 = DoubleConv(channels=c1, **conv3x3)
        self.cv_6 = Conv2D(channels=out_channels, kernel_size=3, strides=1,
                           padding=1, activation='tanh')
        self.ct_1 = Conv2DTranspose(channels=c2, kernel_size=2, strides=2,
                                    activation='relu')
        self.ct_2 = Conv2DTranspose(channels=c1, kernel_size=2, strides=2,
                                    activation='relu')

    def hybrid_forward(self, F, x):
        # Encoder path; l1/l2 are kept for the skip connections.
        l1 = self.cv_1(x)
        l2 = self.cv_2(self.mp_1(l1))
        l3 = self.cv_3(self.mp_2(l2))  # bottleneck
        # Decoder path: upsample, then concatenate with the encoder map of the
        # same resolution along the channel axis (dim=1, NCHW).
        d1 = self.cv_4(F.concat(l2, self.ct_1(l3), dim=1))
        d2 = self.cv_5(F.concat(l1, self.ct_2(d1), dim=1))
        return self.cv_6(d2)

2 Likes