FYI, with your help I was able to quickly put together a shallow implementation of UNet.
from mxnet.gluon import nn
from mxnet.gluon.nn import Conv2D, Conv2DTranspose, BatchNorm, MaxPool2D

class DoubleConv(nn.HybridBlock):
    """Two convolution + batch-norm stages, the basic UNet building block."""
    def __init__(self, **kwargs):
        super(DoubleConv, self).__init__()
        self.conv_1 = Conv2D(**kwargs)
        self.batchnorm_1 = BatchNorm(axis=1)  # axis=1 is the channel axis in NCHW layout
        self.conv_2 = Conv2D(**kwargs)
        self.batchnorm_2 = BatchNorm(axis=1)
    def hybrid_forward(self, F, x):
        # ReLU after each conv/BN pair: the standard UNet double-conv is
        # Conv -> BatchNorm -> ReLU; without a nonlinearity the stacked
        # convolutions are effectively a single linear filter at inference time
        b1 = F.relu(self.batchnorm_1(self.conv_1(x)))
        b2 = F.relu(self.batchnorm_2(self.conv_2(b1)))
        return b2
class Unet(nn.HybridBlock):
    def __init__(self):
        super(Unet, self).__init__()
        # Encoder: 2x2 max-pooling halves the spatial size at each stage
        self.mp_1 = MaxPool2D(pool_size=(2, 2), strides=2)
        self.mp_2 = MaxPool2D(pool_size=(2, 2), strides=2)
        self.cv_1 = DoubleConv(channels=16, kernel_size=3, strides=1, padding=1)
        self.cv_2 = DoubleConv(channels=32, kernel_size=3, strides=1, padding=1)
        self.cv_3 = DoubleConv(channels=64, kernel_size=3, strides=1, padding=1)  # bottleneck
        # Decoder
        self.cv_4 = DoubleConv(channels=32, kernel_size=3, strides=1, padding=1)
        self.cv_5 = DoubleConv(channels=16, kernel_size=3, strides=1, padding=1)
        self.cv_6 = Conv2D(channels=1, kernel_size=3, strides=1, padding=1, activation='tanh')
        # Learned 2x upsampling via transposed convolutions
        self.ct_1 = Conv2DTranspose(channels=32, kernel_size=2, strides=2, activation='relu')
        self.ct_2 = Conv2DTranspose(channels=16, kernel_size=2, strides=2, activation='relu')
    def hybrid_forward(self, F, x):
        # Contracting path
        l1 = self.cv_1(x)
        m1 = self.mp_1(l1)
        l2 = self.cv_2(m1)
        m2 = self.mp_2(l2)
        l3 = self.cv_3(m2)
        # Expanding path with skip connections
        u1 = self.ct_1(l3)
        c1 = F.concat(l2, u1, dim=1)  # concatenate along the channel axis
        u2 = self.cv_4(c1)
        u3 = self.ct_2(u2)
        u4 = F.concat(l1, u3, dim=1)
        u5 = self.cv_5(u4)
        u6 = self.cv_6(u5)
        return u6
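
As a quick smoke test (a minimal sketch, assuming a single-channel NCHW input whose height and width are divisible by 4 so the two pool/upsample stages line up; the 64x64 size is just an example):

import mxnet as mx

net = Unet()
net.initialize(mx.init.Xavier())
net.hybridize()  # compile the graph; shapes are inferred on the first forward pass

x = mx.nd.random.uniform(shape=(1, 1, 64, 64))  # batch of 1, 1 channel, 64x64
y = net(x)
print(y.shape)  # (1, 1, 64, 64): same spatial size back, tanh output in [-1, 1]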