Freezing weight training for certain inputs to a hidden layer

Hi,

I am new to mxnet and gluon. I have the following use case. Can you please advise how to achieve this?

firstNet = gluon.nn.HybridSequential(prefix = 'first')
firstNet.add(gluon.nn.Dense(units = 5))
first = firstNet(X1)

secondNet = gluon.nn.HybridSequential(prefix = 'second')
secondNet.add(gluon.nn.Dense(units = 5))
second = secondNet(X2)

concatData = F.concat(first, second, dim = 1)

thirdNet = gluon.nn.HybridSequential(prefix = 'third')
thirdNet.add(gluon.nn.Dense(units = 2))

return thirdNet(concatData)

I need to freeze updating weights for all connections from secondNet to thirdNet after a phase of training.

Thanks

Hi @nasudhan,

You can use param.grad_req = 'null' to achieve this.

Gluon layers have parameters (for their weights and biases), which by default are updated when you call trainer.step(). You can prevent a parameter from being updated (i.e. freeze it) by skipping the calculation of its gradient, which has the added benefit of speeding up the backward pass slightly (by avoiding unnecessary computation). You can freeze a whole layer by looping through all of its parameters and setting param.grad_req = 'null'.

for param in net.second_net.collect_params().values():
    param.grad_req = 'null'

So for your example:

import mxnet as mx

class MyNetwork(mx.gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super(MyNetwork, self).__init__(**kwargs)
        with self.name_scope():
            self.first_net = mx.gluon.nn.HybridSequential()
            self.first_net.add(mx.gluon.nn.Dense(units = 3))
            self.second_net = mx.gluon.nn.HybridSequential()
            self.second_net.add(mx.gluon.nn.Dense(units = 2))
            self.third_net = mx.gluon.nn.HybridSequential()
            self.third_net.add(mx.gluon.nn.Dense(units = 1))
    
    def hybrid_forward(self, F, x1, x2):
        first = self.first_net(x1)
        second = self.second_net(x2)
        concat = F.concat(first, second, dim = 1)
        third = self.third_net(concat)
        return third

net = MyNetwork()
net.initialize()
trainer = mx.gluon.Trainer(net.collect_params(), optimizer='sgd', optimizer_params={'learning_rate': 1})
x1 = mx.nd.random.uniform(shape=(10, 4))
x2 = mx.nd.random.uniform(shape=(10, 4))
with mx.autograd.record():
    output = net(x1, x2)

net.second_net[0].weight.data()
[[-0.04054644 -0.0198587  -0.05195032  0.03509606]
 [-0.02584003  0.01509629 -0.01908049 -0.02449339]]
<NDArray 2x4 @cpu(0)>

Weights from initialisation.

Update: default

output.backward()
trainer.step(batch_size=x1.shape[0])
net.second_net[0].weight.data()
[[-0.04904126 -0.0277727  -0.06480464  0.02541063]
 [-0.05673689 -0.01368807 -0.06583347 -0.0597207 ]]
<NDArray 2x4 @cpu(0)>

Weights after single update step: change from initialisation.

Update: after freezing

for param in net.second_net.collect_params().values():
    param.grad_req = 'null'
with mx.autograd.record():
    output = net(x1, x2)
output.backward()
trainer.step(batch_size=x1.shape[0])
net.second_net[0].weight.data()
[[-0.04904126 -0.0277727  -0.06480464  0.02541063]
 [-0.05673689 -0.01368807 -0.06583347 -0.0597207 ]]
<NDArray 2x4 @cpu(0)>

Weights after another update step: same as before.
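
If you later want to resume training those weights (e.g. for a later phase of training), you can set grad_req back to 'write' (the default) in the same way. A minimal sketch:

for param in net.second_net.collect_params().values():
    param.grad_req = 'write'

After the next recorded forward/backward pass, trainer.step() will update these weights again.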


Thank you. I understand this freezes the weights for all inputs to second_net. But what I want is to freeze the weights for all inputs from second_net to third_net. For example, if first_net, second_net and third_net have 3, 2 and 2 nodes respectively, the shape of the input weights for third_net would be (2, 5). I want to set grad_req to 'null' for all weights at [:, -2:] and allow [:, :3] to be updated.

Oh sorry, I didn’t quite catch what you were trying to do the first time!

You can separate the layers in third_net for this. So rather than concatenating beforehand, just pass first through one layer and second through another layer, which you can freeze as per the method above, and then sum the results before the activation. Your network structure would look something like:

class MyNetwork(mx.gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super(MyNetwork, self).__init__(**kwargs)
        self.fc1 = mx.gluon.nn.Dense(units = 5)
        self.fc2 = mx.gluon.nn.Dense(units = 3)
        self.fc3a = mx.gluon.nn.Dense(units = 2)
        self.fc3b = mx.gluon.nn.Dense(units = 2)
        self.act = mx.gluon.nn.Activation('sigmoid')
    
    def hybrid_forward(self, F, input1, input2):
        first = self.fc1(input1)
        second = self.fc2(input2)
        third_a = self.fc3a(first)
        third_b = self.fc3b(second)
        third = third_a + third_b
        return self.act(third)
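
To freeze just the path coming from the second input in this version, you'd freeze fc3b (my assumption, since fc3b is the layer that takes second as input), reusing the loop from above. A minimal sketch:

net = MyNetwork()
net.initialize()
for param in net.fc3b.collect_params().values():
    param.grad_req = 'null'

fc3a and the rest of the network keep being updated by trainer.step(), while fc3b's weight and bias stay fixed.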

What’s your use case for this by the way? I’ve not seen it done before.

Thanks a lot. This worked; however, I am seeing weird results when I have a 'relu' activation for my third net after the concatenation operation. Code below; unfreezing does not work as expected when third_net has an activation.

class MyNetwork(mx.gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super(MyNetwork, self).__init__(**kwargs)
        with self.name_scope():
            self.first_net = mx.gluon.nn.HybridSequential()
            self.first_net.add(mx.gluon.nn.Dense(units = 3))
            self.second_net = mx.gluon.nn.HybridSequential()
            self.second_net.add(mx.gluon.nn.Dense(units = 3))
            # fc3a's weights are frozen in the first phase and unfrozen in the second phase.
            self.fc3a = mx.gluon.nn.Dense(units = 2)
            self.fc3b = mx.gluon.nn.Dense(units = 2)
            self.act = mx.gluon.nn.Activation('sigmoid')
            self.third_net = mx.gluon.nn.HybridSequential()
            self.third_net.add(mx.gluon.nn.Dense(units = 2))
            # Everything works fine without this activation. With it, when I unfreeze
            # fc3a in the second-phase model, I am not seeing its weights being updated.
            self.third_net.add(mx.gluon.nn.Activation(activation = 'relu'))

    def hybrid_forward(self, F, x1, x2):
        first = self.first_net(x1)
        second = self.second_net(x2)
        third_a = self.fc3a(first)
        third_b = self.fc3b(second)
        third = third_a + third_b
        third_act = self.act(third)                                      
        third_net = self.third_net(third_act)
        return third_net

PS: I am exploring this to get insights into a specific regression problem.

Sounds like the ReLUs are getting stuck: gradient back-propagation is being 'blocked' because the ReLU's input is in the negative range, where the gradient is zero. It's a common issue with ReLU. Check the input to the ReLU (in the second phase) and confirm that at least some values are positive, otherwise there will be no gradient.

You need to refactor your model slightly if you want to easily set a breakpoint for this. Avoid placing the ReLU inside third_net, and just apply it in the hybrid_forward call using F.relu. You'll then be able to set a breakpoint on that line. And don't hybridize while debugging.
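
As a rough sketch of that refactor (the names MyNetworkDebug and fc4 are just placeholders I'm introducing; everything else mirrors your class above):

class MyNetworkDebug(mx.gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super(MyNetworkDebug, self).__init__(**kwargs)
        with self.name_scope():
            self.first_net = mx.gluon.nn.HybridSequential()
            self.first_net.add(mx.gluon.nn.Dense(units = 3))
            self.second_net = mx.gluon.nn.HybridSequential()
            self.second_net.add(mx.gluon.nn.Dense(units = 3))
            self.fc3a = mx.gluon.nn.Dense(units = 2)
            self.fc3b = mx.gluon.nn.Dense(units = 2)
            self.act = mx.gluon.nn.Activation('sigmoid')
            self.fc4 = mx.gluon.nn.Dense(units = 2)  # the Dense that used to sit inside third_net

    def hybrid_forward(self, F, x1, x2):
        third = self.fc3a(self.first_net(x1)) + self.fc3b(self.second_net(x2))
        pre_relu = self.fc4(self.act(third))
        # Set a breakpoint here (with the network not hybridized): if no entries of
        # pre_relu are positive, the ReLU will pass back zero gradient everywhere.
        return F.relu(pre_relu)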

Thank you! After experimenting a bit, I am using Swish instead of ReLU to mitigate the dying-ReLU problem.
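
For reference, the substitution looks roughly like this (assuming an MXNet version that ships gluon.nn.Swish):

# Final block: Dense followed by Swish instead of ReLU.
# swish(x) = x * sigmoid(beta * x), which keeps a non-zero gradient for
# moderately negative inputs, unlike ReLU.
third_net = mx.gluon.nn.HybridSequential()
third_net.add(mx.gluon.nn.Dense(units = 2))
third_net.add(mx.gluon.nn.Swish())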

In the documentation I found that you don't need an explicit for loop to freeze a layer. You can simply write model.collect_params().setattr('grad_req', 'null'), which sets grad_req to 'null' for all the parameters. So in your case, you can write net.second_net.collect_params().setattr('grad_req', 'null')
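
For instance, with the net instance from the example earlier in this thread, the one-liner and a quick check would look like this (a sketch):

# Same effect as the earlier for loop: freeze every parameter of second_net.
net.second_net.collect_params().setattr('grad_req', 'null')

# Verify: second_net has a single Dense layer, so this prints ['null', 'null']
# (its weight and bias).
print([p.grad_req for p in net.second_net.collect_params().values()])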

Below is a fuller explanation of what you can do with collect_params():

  • Iterating over collect_params() yields str objects, the parameters' names.

    for param in model.collect_params():
        print(type(param))
        break
    # will print `<class 'str'>`, showing that it's a str object.
    
  • collect_params().values() gives the parameters themselves, as gluon.Parameter objects.

    for param in model.collect_params().values():
        print(type(param))
        break
    # will print `<class 'mxnet.gluon.parameter.Parameter'>`, showing that it's a Parameter object.
    
  • A gluon.Parameter has the following properties that you can change:

    • grad_req ('write', 'add' or 'null'),
      which tells whether a gradient is required for that parameter or not
    • lr_mult (float),
      the local learning rate multiplier for this Parameter; the actual learning rate is learning_rate * lr_mult
    • wd_mult (float),
      the local weight decay multiplier for this Parameter
  • If you want to set a property on all the parameters of all layers, simply do as below:

    model.collect_params().setattr('grad_req', 'null') # could be "setattr('wd_mult', 0.01).... etc"
    
  • If you want to set a property on a specific parameter of a specific layer, follow the procedure below:

  1. Print the model
    print(model)
    # will print something like
    '''
    Sequential(
      (0): Dense(784 -> 512, Activation(relu))
      (1): Dense(512 -> 512, Activation(relu))
      (2): Dense(512 -> 256, Activation(relu))
      (3): Dense(256 -> 128, Activation(relu))
      (4): Dense(128 -> 10, linear)
    )
    '''
    
  2. Then you can select any layer by its index as shown above
    print(model[0])
    # will print
    # Dense(784 -> 512, Activation(relu))
    
  3. Then change any property you want to
    model[0].collect_params().setattr('grad_req', 'null') # could be "setattr('wd_mult', 0.01).... etc"
    # this will set "grad_req" of all weights and biases equal to "null"
    
  4. If you want to change a property of just one of the parameters of a specific layer, for example setting grad_req to 'null' only for the weight and not for the bias:
    model[0].collect_params('.*weight').setattr('grad_req', 'null')
    # this will set "grad_req" to null for only weights, not for the bias.
    
    This example shows that collect_params() supports regular expressions; check the documentation for more info (see also the sketch right after this list).
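
Tying this back to the network from this thread, the regex selection lets you freeze only the weight of second_net while leaving its bias trainable (a sketch, assuming net is the MyNetwork instance defined earlier):

# Freeze only the weight parameters under second_net; biases stay trainable.
net.second_net.collect_params('.*weight').setattr('grad_req', 'null')

print(net.second_net[0].weight.grad_req)  # 'null'
print(net.second_net[0].bias.grad_req)    # 'write' (still updated by trainer.step())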

I am trying to implement the lottery ticket hypothesis, and it requires randomly freezing parameters. Is it possible to do that?