Hybridizing GoogLeNet

Hi, I’m trying to adapt the GoogLeNet/InceptionV1 implementation in the online book d2l.ai to be compatible with hybridization. However, I’m currently facing issues with mx.np.concatenate. Here’s a full minimal example with the network implementation:

import d2l # d2l.ai book code
import mxnet as mx
from mxnet import gluon, metric, np, npx
from mxnet.gluon import nn
npx.set_np()

ctx_list = [npx.gpu(i) for i in range(npx.num_gpus())]
mx.random.seed(42, ctx='all')

class Inception(nn.HybridBlock):
    # c1-c4 are the number of output channels for each path
    def __init__(self, c1, c2, c3, c4, **kwargs):
        super().__init__(**kwargs)
        # Path 1 is a single 1 x 1 convolutional layer
        self.p1_1 = nn.Conv2D(c1, kernel_size=1, activation='relu')
        # Path 2 is a 1 x 1 convolutional layer followed by a 3 x 3
        # convolutional layer
        self.p2_1 = nn.Conv2D(c2[0], kernel_size=1, activation='relu')
        self.p2_2 = nn.Conv2D(c2[1], kernel_size=3, padding=1, activation='relu')
        # Path 3 is a 1 x 1 convolutional layer followed by a 5 x 5
        # convolutional layer
        self.p3_1 = nn.Conv2D(c3[0], kernel_size=1, activation='relu')
        self.p3_2 = nn.Conv2D(c3[1], kernel_size=5, padding=2,
                              activation='relu')
        # Path 4 is a 3 x 3 max-pooling layer followed by a 1 x 1
        # convolutional layer
        self.p4_1 = nn.MaxPool2D(pool_size=3, strides=1, padding=1)
        self.p4_2 = nn.Conv2D(c4, kernel_size=1, activation='relu')

    def hybrid_forward(self, F, x):
        p1 = self.p1_1(x)
        p2 = self.p2_2(self.p2_1(x))
        p3 = self.p3_2(self.p3_1(x))
        p4 = self.p4_2(self.p4_1(x))
        # Concatenate the outputs on the channel dimension
        return np.concatenate((p1, p2, p3, p4), axis=1)
        #return F.concat(p1, p2, p3, p4, dim=1)

class GoogLeNet(nn.HybridBlock):
    """
    GoogLeNet uses a stack of 9 Inception blocks in total, followed by global average pooling
    """
    def __init__(self, classes=1000, **kwargs):
        super().__init__(**kwargs)

        self.net = nn.HybridSequential()
        
        # First component uses a 64-channel 7 x 7 convolutional layer
        self.net.add(
            nn.Conv2D(64, kernel_size=7, strides=2, padding=3, activation='relu'),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1)
        )

        # Second component uses two convolutional layers:
        # first a 64-channel 1 x 1 convolutional layer,
        # then a 3 x 3 convolutional layer that triples the number of channels.
        # This corresponds to the second path in the Inception block.
        self.net.add(
            nn.Conv2D(64, kernel_size=1, activation='relu'),
            nn.Conv2D(192, kernel_size=3, padding=1, activation='relu'),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1)
        )
    
        # Third component connects two complete Inception blocks in series.
        # The number of output channels of the first block is 64+128+32+32=256,
        # and the ratio among the output channels of the four paths is 2:4:1:1.
        # The number of output channels of the second block is 128+192+96+64=480,
        # and the ratio among the output channels per path is 4:6:3:2.
        self.net.add(
            Inception(64, (96, 128), (16, 32), 32),
            Inception(128, (128, 192), (32, 96), 64)
        )

        # Fourth component connects five Inception blocks in series
        self.net.add(
            Inception(192, (96, 208), (16, 48), 64),
            Inception(160, (112, 224), (24, 64), 64),
            Inception(128, (128, 256), (24, 64), 64),
            Inception(112, (144, 288), (32, 64), 64),
            Inception(256, (160, 320), (32, 128), 128),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1)
        )

        # Fifth component has two Inception blocks followed by global
        # average pooling and the output layer
        self.net.add(
            Inception(256, (160, 320), (32, 128), 128),
            Inception(384, (192, 384), (48, 128), 128),
            nn.GlobalAvgPool2D(),
            nn.Dense(classes)
        )
    
    def hybrid_forward(self, F, x):
        x = self.net(x)
        return x


net = GoogLeNet(classes=10)
net.initialize()
net.hybridize()

train_dl, valid_dl = d2l.load_data_fashion_mnist(batch_size=128, resize=96)

loss = gluon.loss.SoftmaxCrossEntropyLoss()
optimizer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

d2l.train_ch13(net, train_iter=train_dl, test_iter=valid_dl, 
               loss=loss, trainer=optimizer, 
               num_epochs=10, ctx_list=ctx_list)

With return np.concatenate((p1, p2, p3, p4), axis=1) I get the error:

AssertionError: Positional arguments must have NDArray type, but got <_Symbol conv3_relu_fwd>
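As far as I can tell, this happens because after net.hybridize() the first forward pass traces hybrid_forward with symbols instead of arrays, and the standalone np.concatenate only accepts real ndarrays. A tiny probe block (my own toy example, assuming MXNet 1.x with npx.set_np() active) makes the switch visible:

import mxnet as mx
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()

class Probe(nn.HybridBlock):
    def hybrid_forward(self, F, x):
        # Imperatively x is an mxnet.numpy.ndarray; once the block is
        # hybridized and traced, x is a symbol instead
        print(type(x))
        return x + 1

probe = Probe()
probe.initialize()
probe(np.ones((1, 2)))  # prints the ndarray type
probe.hybridize()
probe(np.ones((1, 2)))  # prints a _Symbol type during tracing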

And with return F.concat(p1, p2, p3, p4, dim=1) I get the error:

TypeError: Operator concat registered in backend is known as concat in Python. This is a legacy operator which can only accept legacy ndarrays, while received an MXNet numpy ndarray. Please call as_nd_ndarray() upon the numpy ndarray to convert it to a legacy ndarray, and then feed the converted array to this operator.
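For reference, that message is about MXNet's split between the legacy NDArray operators and the numpy-compatible ones: with npx.set_np() every array is an mxnet.numpy.ndarray, while legacy operators such as concat only accept legacy NDArrays. A quick illustration of the conversion the message suggests (again MXNet 1.x, outside of hybrid_forward):

from mxnet import nd, np, npx
npx.set_np()

a = np.ones((2, 3))                     # numpy-compatible ndarray
legacy = a.as_nd_ndarray()              # convert to a legacy NDArray
out = nd.concat(legacy, legacy, dim=1)  # legacy operator accepts it
print(out.shape)                        # (2, 6)

That workaround doesn't help inside hybrid_forward, though, where the inputs may be symbols, so I'd rather find the proper numpy-style call.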

Any suggestions on how I can modify this for hybridization would be greatly appreciated!

Turns out I needed to go through F.np: inside hybrid_forward, F is mxnet.ndarray when running imperatively and mxnet.symbol once the block is hybridized, and the numpy-compatible concatenate has to be reached as F.np.concatenate rather than through the global np module.
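Here is the fixed method as a minimal sketch (only the return line changes):

    def hybrid_forward(self, F, x):
        p1 = self.p1_1(x)
        p2 = self.p2_2(self.p2_1(x))
        p3 = self.p3_2(self.p3_1(x))
        p4 = self.p4_2(self.p4_1(x))
        # F.np dispatches to mxnet.ndarray.numpy or mxnet.symbol.numpy,
        # so the same line works before and after hybridization
        return F.np.concatenate((p1, p2, p3, p4), axis=1)

With that change, net.hybridize() no longer breaks the forward pass for me.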