Alternating training of parts of a model with mx.sym


Suppose you have a model with two parts, partA and partB (partA + partB = the whole model). You want to train partA for a while with partB fixed, and then train partB with partA fixed.

In Gluon, it seems like I can do this with something like the following:

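from mxnet import autograd, gluon

# One trainer per part, so each part can be updated independently.
partA_trainer = gluon.Trainer(net.partA.collect_params(), 'adam', {'learning_rate': lr})
partB_trainer = gluon.Trainer(net.partB.collect_params(), 'adam', {'learning_rate': lr})

with autograd.record():
    loss = net(data)
loss.backward()

# Step only the trainer of the part being trained in this block of epochs.
# batch_size is assumed to be defined elsewhere.
if (epoch // alternate_epochs) % 2 == 0:
    training_info = '<<< partA training only >>>'
    partA_trainer.step(batch_size)
else:
    training_info = '<<< partB training only >>>'
    partB_trainer.step(batch_size)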
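To make the alternation concrete, the condition above can be checked on its own. This is a minimal sketch in plain Python; the helper name `active_part` is made up for illustration, and `alternate_epochs` is assumed to be a positive number of epochs per phase.

```python
def active_part(epoch, alternate_epochs):
    """Return which part is trained in the given (0-based) epoch."""
    if alternate_epochs <= 0:
        raise ValueError("alternate_epochs must be positive")
    # Even-numbered blocks of `alternate_epochs` epochs go to partA,
    # odd-numbered blocks go to partB.
    return 'partA' if (epoch // alternate_epochs) % 2 == 0 else 'partB'


# With alternate_epochs = 2, the parts switch every two epochs.
schedule = [active_part(epoch, 2) for epoch in range(6)]
print(schedule)
# ['partA', 'partA', 'partB', 'partB', 'partA', 'partA']
```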

I wonder if there’s an easy way to do something similar with the mx.sym API.


This GitHub issue seems to be what you are looking for:

Freeze the gradients of the parts that you don’t want to update
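
With the symbolic API, one way to apply this advice is probably to decide by name which arguments stay frozen and pass that list to the `fixed_param_names` argument of `mx.mod.Module`. The sketch below only builds the name lists in plain Python. It assumes the parameter names carry a `partA_` or `partB_` prefix, which is a naming convention invented for this example. The entries in `arg_names` are also made up and stand in for whatever `sym.list_arguments()` returns for your network.

```python
def split_frozen(arg_names, train_prefix, input_names=('data', 'softmax_label')):
    """Split argument names into (trainable, frozen) lists.

    Parameters whose name starts with `train_prefix` are trainable and
    every other parameter is frozen. Input names (data, labels) are skipped.
    """
    params = [name for name in arg_names if name not in input_names]
    trainable = [name for name in params if name.startswith(train_prefix)]
    frozen = [name for name in params if not name.startswith(train_prefix)]
    return trainable, frozen


# Hypothetical argument names, as sym.list_arguments() might return them.
arg_names = ['data', 'partA_fc_weight', 'partA_fc_bias',
             'partB_fc_weight', 'partB_fc_bias', 'softmax_label']

trainable_a, frozen_a = split_frozen(arg_names, 'partA_')
trainable_b, frozen_b = split_frozen(arg_names, 'partB_')

print(frozen_a)  # parameters to freeze while partA trains
print(frozen_b)  # parameters to freeze while partB trains

# Assumed usage, not run here: one module per training phase, e.g.
#   mod_a = mx.mod.Module(sym, fixed_param_names=frozen_a)
#   mod_b = mx.mod.Module(sym, fixed_param_names=frozen_b)
```

Since `fixed_param_names` is set when the module is constructed, alternating would likely mean keeping one module per phase and copying the weights across with `get_params()` / `set_params()` when switching. That part is an untested assumption, so check it against the linked issue.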