Fixing (freezing) the weights of specific layers in the .fit() function


Hi guys,

I know I can use the fixed_param_names parameter to fix the weights of certain layers. However, this parameter seems to be available only when creating a new Module object, i.e. mx.mod.Module(..., fixed_param_names=...).

Say my network has two layers, layer0 and layer1, and the pre-trained weights have already been loaded.

Is it possible to do the following?

I want to train for 10 epochs.

In the first 5 epochs:
fix layer0’s weights and update only layer1’s weights.

In the next 5 epochs:
fix layer1’s weights and update only layer0’s weights.

This would be very easy to do if the fit() function accepted the fixed_param_names parameter.
Or is there another easy way to do this?