Finetuning in MXNet for ConvNet blocks

Hi,

You may want to wait for input from the MXNet experts, but this should get you started (with Gluon).

Every variable in a network is stored in a gluon.Parameter object. Each Parameter has a grad_req property with three options: 'write' overwrites the stored gradient on every backward pass (the usual setting for trainable parameters), 'null' means no gradient is computed, so the parameter is not trainable, and 'add' accumulates new gradients into the existing ones. As far as I understand, the accumulation for 'add' happens on the backward pass rather than in trainer.step, so you have to zero the gradients yourself.

So in your case you can check explicitly which parameters of the pre-trained network have grad_req == 'null', and change that for the ones you want to train. I am under the impression, but I may be wrong, that pre-trained networks are trainable by default, even when you load them with pretrained=True. Here is an example I just ran:

Import the essentials:

import mxnet as mx
from mxnet import gluon

# This loads a pre-trained network
net = gluon.model_zoo.vision.resnet18_v2(pretrained=True)

If you evaluate net in the shell, you’ll see the layers of the network:

net

# prints
ResNetV2(
  (features): HybridSequential(
    (0): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=True, use_global_stats=False, in_channels=3)
    (1): Conv2D(3 -> 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
    (3): Activation(relu)
    (4): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False)
    (5): HybridSequential(
      (0): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv1): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv1): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (6): HybridSequential(
      (0): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv1): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (downsample): Conv2D(64 -> 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        (conv1): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (7): HybridSequential(
      (0): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        (conv1): Conv2D(128 -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (downsample): Conv2D(128 -> 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        (conv1): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (8): HybridSequential(
      (0): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        (conv1): Conv2D(256 -> 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
        (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (downsample): Conv2D(256 -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
        (conv1): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
        (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (9): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
    (10): Activation(relu)
    (11): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True)
    (12): Flatten
  )
  (output): Dense(512 -> 1000, linear)
)

In this pre-trained network, net.features is a HybridSequential container whose layers you can access by simple indexing (just like a Python list). For example, say I want to see the grad_req value (a string) of the 2nd layer of features (1st: BatchNorm, 2nd: Conv2D):

net.features[1] # Note that index 0 refers to the BatchNorm layer

# prints
Conv2D(3 -> 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

Let’s see the grad_req:

for param in net.features[1].collect_params().values():
    print(param.name, ", ", param.grad_req)

# prints; note there is no bias, since the layer is followed by a BatchNorm (bias=False)
resnetv20_conv0_weight ,  write

So the convolutional weight is writable (i.e. trainable). If you want to freeze some of the parameters/layers of the pre-trained network (so that they are not updated), run the same loop with an assignment, setting grad_req to 'null' wherever it is currently 'write':

for param in net.features[1].collect_params().values(): # Or some other layers that you want. 
    param.grad_req='null'

The above summarises my understanding of pre-trained networks. I may be wrong, but you can experiment. Hope this helps.
