Modify structure of loaded network


#1

I want to load a pretrained model in gluon, modify its structure, then fine-tune and predict. For example, I want to replace all convolutional layers with a custom layer that has the same input and output shapes. How can I do this?

I’ve tried directly assigning to net._children[index] but this has no effect on execution. There are some examples of how to do this with the final layer, but I have not found any examples or description of how to do this with an intermediate layer. Also I would like to be able to, for example, replace all instances of Conv2D in a network automatically (examples only show how to manually identify a layer by its name and extract its output).


#2

One way to do it, though not without caveats, is to iterate through your layers and replace the ones you want with new ones. The major issue with this approach arises if you already have a custom layer with, let’s say, a hybrid_forward function that uses F.Convolution directly: that operator call is not a child layer, so iterating over the children will not find it.

Here is an example where I replace the Conv2D layers of a ResNet with Conv2D layers that use a bias.

# Load a pretrained ResNet-18 (v1) from the Gluon model zoo.
net = gluon.model_zoo.vision.resnet18_v1(pretrained=True)

# Inspect the bias of the first layer in `net.features` before any replacement.
type(net.features[0].bias)
NoneType
def replace_conv2D(net):
    """Recursively swap each Conv2D child of ``net`` for a new Conv2D with a bias.

    The new layer reuses the channels, kernel size, strides and padding
    recorded on the layer it replaces, is registered on the parent under the
    same key, and is then Xavier-initialized (the old weights are not copied).
    Children that are not Conv2D are searched recursively.
    """
    for key, layer in net._children.items():
        if not isinstance(layer, gluon.nn.Conv2D):
            # Recursively replace layers
            replace_conv2D(layer)
            continue

        new_conv = gluon.nn.Conv2D(
            channels=layer._channels,
            kernel_size=layer._kwargs['kernel'],
            strides=layer._kwargs['stride'],
            padding=layer._kwargs['pad'],
            use_bias=True)
        with net.name_scope():
            # Same key, so the new layer takes the old layer's slot in _children.
            net.register_child(new_conv, key)
        new_conv.initialize(mx.init.Xavier())
        print('Replacing layer')

replace_conv2D(net)
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
Replacing layer
print(net(mx.nd.ones((3,224,224,224))))
[[  2.57089067  15.61428833  -8.22652912 ..., -29.74064636  -3.43608236
   -6.216084  ]
 [  2.57089067  15.61428833  -8.22652912 ..., -29.74064636  -3.43608236
   -6.216084  ]
 [  2.57089067  15.61428833  -8.22652912 ..., -29.74064636  -3.43608236
   -6.216084  ]]
<NDArray 3x1000 @cpu(0)>
print(type(net.features[0].bias))
mxnet.gluon.parameter.Parameter

#3

Thanks, Thomas. What version of mxnet are you using to run this code?

I’m using mxnet 1.1 and this line fails for me:

for key, layer in net._children.items():

Here is the error:

AttributeError: 'list' object has no attribute 'items'


#4

The error is gone with mxnet 1.2.1


#5

The code you provided above runs without error using mxnet 1.2.1, but it doesn’t seem to achieve the desired effect. I don’t see the change reflected when I run a forward pass on the modified net.

Here’s a demonstration with a toy example:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.ndarray import tanh, relu

# define a layer that will raise an exception on forward
class BrokenConv(nn.Conv2D):
    """Conv2D subclass that raises on any forward pass.

    Acts as a tripwire: if data is really routed through this layer, the
    call fails with an exception.
    """

    def forward(*args, **kwargs):
        # Declared without an explicit ``self``; the instance arrives as args[0].
        raise Exception()

    def hybrid_forward(*args, **kwargs):
        raise Exception()

# define standard LeNet
class LeNet(gluon.HybridBlock):
    """LeNet-style network: two conv + max-pool stages followed by two dense layers.

    The output of every conv and dense layer is passed through ``tanh``.
    The final dense layer has 10 units.
    """

    def __init__(self, kernel_size=(5,5), num_filters=(20, 50), pool_size=(2,2), strides=(2,2), ff_hidden=500, **kwargs):
        super(LeNet, self).__init__(**kwargs)
        pool_args = dict(pool_size=pool_size, strides=strides)
        with self.name_scope():
            # Children are stored as attributes; forward() looks them up by name.
            self.conv1 = nn.Conv2D(num_filters[0], kernel_size=kernel_size)
            self.pool1 = nn.MaxPool2D(**pool_args)
            self.conv2 = nn.Conv2D(num_filters[1], kernel_size=kernel_size)
            self.pool2 = nn.MaxPool2D(**pool_args)
            self.fc1 = nn.Dense(ff_hidden)
            self.fc2 = nn.Dense(10)

    def forward(self, x):
        """Run ``x`` through both conv/pool stages and the two dense layers."""
        stage1 = self.pool1(tanh(self.conv1(x)))
        stage2 = self.pool2(tanh(self.conv2(stage1)))
        # Keep the first axis and flatten everything else before the dense layers.
        flat = stage2.reshape((0, -1))
        hidden = tanh(self.fc1(flat))
        return tanh(self.fc2(hidden))

# get a net
net = LeNet()
# request hybridization before the parameters are initialized
net.hybridize()
ctx = [mx.cpu()]
# initialize all parameters on the CPU context
net.collect_params().initialize(ctx=ctx)

# run a forward pass
# presumably NCHW: 1024 samples, 3 channels, 32x32 — TODO confirm
data = mx.nd.ones((1024, 3, 32, 32))
# baseline pass with the original, unreplaced layers
out = net(data)

def replace_conv2D(net):
    """Recursively swap each Conv2D child of ``net`` for a BrokenConv.

    The new layer reuses the channels, kernel size, strides and padding
    recorded on the layer it replaces, is registered on the parent under the
    same key, and is then Xavier-initialized. Children that are not Conv2D
    are searched recursively.
    """
    for key, layer in net._children.items():
        if not isinstance(layer, gluon.nn.Conv2D):
            # Recursively replace layers
            replace_conv2D(layer)
            continue

        new_conv = BrokenConv(
            channels=layer._channels,
            kernel_size=layer._kwargs['kernel'],
            strides=layer._kwargs['stride'],
            padding=layer._kwargs['pad'],
            use_bias=True)
        with net.name_scope():
            # Only the registered-children entry is updated here.
            net.register_child(new_conv, key)
        new_conv.initialize(mx.init.Xavier())
        print('Replacing layer')

# Swap every Conv2D in the net for a BrokenConv.
replace_conv2D(net)

# If the replacement had taken effect, this forward pass would raise.
out = net(data)

That last line should raise an exception, but it doesn’t. If I use BrokenConv directly in the definition of LeNet and run a forward pass, it raises an exception as expected.


#6

Try this:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.ndarray import tanh, relu

# define a tripwire layer: a Conv2D subclass whose forward and hybrid_forward always raise
class BrokenConv(nn.Conv2D):
    def forward(*x, **y):
        raise Exception()
    def hybrid_forward(*x, **y):
        raise Exception()

# define standard LeNet
class LeNet(gluon.HybridBlock):
    """LeNet-style network: two conv + max-pool stages followed by two dense layers.

    The output of every conv and dense layer is passed through ``tanh``.
    The final dense layer has 10 units.
    """

    def __init__(self, kernel_size=(5,5), num_filters=(20, 50), pool_size=(2,2), strides=(2,2), ff_hidden=500, **kwargs):
        super(LeNet, self).__init__(**kwargs)
        pool_args = dict(pool_size=pool_size, strides=strides)
        with self.name_scope():
            # Children are stored as attributes; forward() looks them up by name.
            self.conv1 = nn.Conv2D(num_filters[0], kernel_size=kernel_size)
            self.pool1 = nn.MaxPool2D(**pool_args)
            self.conv2 = nn.Conv2D(num_filters[1], kernel_size=kernel_size)
            self.pool2 = nn.MaxPool2D(**pool_args)
            self.fc1 = nn.Dense(ff_hidden)
            self.fc2 = nn.Dense(10)

    def forward(self, x):
        """Run ``x`` through both conv/pool stages and the two dense layers."""
        stage1 = self.pool1(tanh(self.conv1(x)))
        stage2 = self.pool2(tanh(self.conv2(stage1)))
        # Keep the first axis and flatten everything else before the dense layers.
        flat = stage2.reshape((0, -1))
        hidden = tanh(self.fc1(flat))
        return tanh(self.fc2(hidden))

# get a net
net = LeNet()

ctx = [mx.cpu()]
# initialize the parameters on the CPU context (this version does not call hybridize())
net.initialize(ctx=ctx)

# input batch for the forward pass, which is run only after the layers are replaced
# presumably NCHW: 1024 samples, 3 channels, 32x32 — TODO confirm
data = mx.nd.ones((1024, 3, 32, 32))

def replace_conv2D(net):
    """Recursively replace every ``Conv2D`` under ``net`` with a ``BrokenConv``.

    The replacement is built from the hyper-parameters recorded on the
    original layer (channels, kernel size, strides, padding, dilation, groups
    and layout), so dilated and grouped convolutions keep their output shape.
    It always has a bias and is freshly Xavier-initialized; the original
    weights, activation and initializers are not carried over.

    The new layer is stored in ``net._children`` and, when ``net`` also has an
    attribute with the same name as the child key, in that attribute too.
    Custom blocks call their children through attributes in ``forward``, so
    updating ``_children`` alone would leave the old layer in use.

    Relies on private Gluon attributes (``_children``, ``_channels``,
    ``_kwargs``). ``_children`` must be a dict: this works on MXNet 1.2.1 but
    fails on 1.1, where it is a list.

    Args:
        net: the Gluon block whose children are rewritten in place.
    """
    # Iterate over a snapshot so that re-registering children cannot disturb the loop.
    for key, layer in list(net._children.items()):
        if not isinstance(layer, gluon.nn.Conv2D):
            # Recursively replace layers
            replace_conv2D(layer)
            continue

        conv_kwargs = layer._kwargs
        new_conv = BrokenConv(
            channels=layer._channels,
            kernel_size=conv_kwargs['kernel'],
            strides=conv_kwargs['stride'],
            padding=conv_kwargs['pad'],
            dilation=conv_kwargs['dilate'],
            groups=conv_kwargs['num_group'],
            layout=conv_kwargs['layout'],
            use_bias=True)
        with net.name_scope():
            # Blocks that reference the child by attribute need the attribute updated too.
            if hasattr(net, key):
                setattr(net, key, new_conv)
            net.register_child(new_conv, key)
        new_conv.initialize(mx.init.Xavier())
        print('Replacing layer '+key)

# Swap every Conv2D in the net for a BrokenConv.
replace_conv2D(net)

# The forward pass now reaches BrokenConv.forward, which raises.
out = net(data)
<ipython-input-210-55f25459abc5> in forward(*x, **y)
      7 class BrokenConv(nn.Conv2D):
      8     def forward(*x, **y):
----> 9         raise Exception()
     10     def hybrid_forward(*x, **y):
     11         raise Exception()

Exception: 

Updated to set the attribute as well. This is needed because of a duality in how child blocks are referenced: a container such as HybridSequential stores its children only in the _children ordered dict, whereas custom blocks also hold their child blocks as attributes, and their forward passes reference them directly by attribute name.

Obviously this is a bit hacky and might not stand the test of time. The ideal way is to construct a new network and cherry-pick what you need from the original one by iterating through it. That is easy to do if you know the structure of your network, but quite hacky if you want to build a completely generic method.