Why is there no simple NetDrop-like function for net surgery - would be very useful


#1

For example, the Wolfram Language's new neural network toolbox uses MXNet as its backend, and provides some very helpful functions for net surgery, such as NetDrop:
http://reference.wolfram.com/language/ref/NetDrop.html

PyTorch also makes it easy to drop a single layer or otherwise modify a network. However, AFAIK (and I've been looking around for a while), there is no simple way to do this in MXNet. The easiest way seems to be to define your new net from scratch to match the network you're pruning, and then copy the params over. Am I missing something in MXNet?
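To illustrate, here is a minimal sketch of that workaround in Gluon (the Dense layer sizes and the kept-layer mapping are made up for illustration; it drops the middle layer of a three-layer net):

from mxnet import nd
from mxnet.gluon import nn

# original net: three Dense layers; suppose we want to drop the middle one
full = nn.Sequential()
full.add(nn.Dense(16, activation='relu'))
full.add(nn.Dense(16, activation='relu'))
full.add(nn.Dense(10))
full.initialize()
full(nd.ones((1, 8)))  # forward pass so the deferred shapes get inferred

# pruned net, rebuilt from scratch with matching shapes for the kept layers
pruned = nn.Sequential()
pruned.add(nn.Dense(16, activation='relu'))
pruned.add(nn.Dense(10))
pruned.initialize()
pruned(nd.ones((1, 8)))

# copy params over: full layer 0 -> pruned layer 0, full layer 2 -> pruned layer 1
for src, dst in [(0, 0), (2, 1)]:
    pruned[dst].weight.set_data(full[src].weight.data())
    pruned[dst].bias.set_data(full[src].bias.data())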


#2

Hi @ebeall,

Assuming you’re working with Sequential block models, this is possible in MXNet Gluon without duplicating the network, but you’ll need to reference an internal property at the moment. Sequential simply means the output from one layer gets passed as input to the next layer, and it’s the format of model used by Wolfram, judging by the docs you referenced.

A Sequential block has children (the internal property for this is called _children), and you can remove and swap these as you wish. Check out an example of this below. But be careful, because the input shapes need to be compatible after removing layers, as noted in the Wolfram docs too: “When removing layers from the interior of a NetChain, the input shape of the first removed layer and the output shape of the last removed layer must match.”

import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn

net = nn.Sequential()
net.add(nn.Conv2D(channels=8, kernel_size=3, padding=1, activation='relu'))
net.add(nn.Conv2D(channels=8, kernel_size=3, padding=1, activation='relu'))
net.add(nn.Conv2D(channels=8, kernel_size=3, padding=1, activation='relu'))
net.add(nn.Conv2D(channels=8, kernel_size=3, padding=1, activation='relu'))
net.add(nn.Conv2D(channels=8, kernel_size=3, padding=1, activation='relu'))
net.initialize()

# run a forward pass so the layer shapes get inferred
data = nd.random.normal(shape=(1, 3, 32, 32))
out = net(data)
print(net._children)
# OrderedDict([('0',
#               Conv2D(3 -> 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
#              ('1',
#               Conv2D(8 -> 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
#              ('2',
#               Conv2D(8 -> 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
#              ('3',
#               Conv2D(8 -> 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
#              ('4',
#               Conv2D(8 -> 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))])

# example of removing the 3rd and 4th layers (keys '2' and '3')
for i in range(2, 4):
    del net._children[str(i)]

# the shortened network still runs end to end
out = net(data)
print(net._children)
# OrderedDict([('0',
#               Conv2D(3 -> 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
#              ('1',
#               Conv2D(8 -> 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
#              ('4',
#               Conv2D(8 -> 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))])
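
If you'd rather not touch the private property at all, an alternative sketch (my own variation, relying on Sequential's integer indexing) is to assemble a new Sequential from the blocks you want to keep; the kept blocks and their parameters are shared with the original net, not copied:

# starting again from the original five-layer net (before any deletion),
# keep the 1st, 2nd and 5th conv layers by re-adding them to a new container
pruned = nn.Sequential()
for i in [0, 1, 4]:
    pruned.add(net[i])

out = pruned(data)  # parameters are shared with `net`, not copied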