Does anyone know what the problem is with the following code? Does it have anything to do with recreating the gluon.Trainer every batch?
```python
import datetime
import sys

import mxnet as mx
from mxnet import autograd, gluon, nd
from mxnet.gluon import utils as gutils

sys.path.append('..')
import utils

# criterion (the loss function) is defined elsewhere in my script


def multi_gpu_train_batch(net, data, label, batch_size, lr, momentum, weight_decay, contexts):
    # Note: the Trainer is recreated here on every batch
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'momentum': momentum, 'wd': weight_decay})
    # split the batch across the available devices
    data_list = gutils.split_and_load(data, contexts)
    label_list = gutils.split_and_load(label, contexts)
    ls = []
    with autograd.record():
        i = 0
        for each_data, each_label in zip(data_list, label_list):
            output = net(each_data)
            loss = criterion(output, each_label)
            ls.append(loss)
            if i == 0:
                outputs = output
                losses = loss
            else:
                outputs = nd.concat(outputs, output.as_in_context(outputs.context), dim=0)
                losses = nd.concat(losses, loss.as_in_context(losses.context), dim=0)
            i = i + 1
    for l in ls:
        l.backward()
    trainer.step(batch_size)
    return outputs, losses


def multi_gpu_train(net, train_data, valid_data, num_epochs, batch_size,
                    lr, momentum, weight_decay, contexts, lr_period, lr_decay):
    prev_time = datetime.datetime.now()
    i = 0
    length = len(lr_period)
    for epoch in range(num_epochs):
        train_loss = 0.0
        train_acc = 0.0
        if epoch > 0 and i < length and epoch == lr_period[i]:
            lr = lr * lr_decay
            i = i + 1
        for data, label in train_data:
            output, loss = multi_gpu_train_batch(net, data, label, batch_size, lr,
                                                 momentum, weight_decay, contexts)
            nd.waitall()
            train_loss += nd.mean(loss).asscalar()
            train_acc += utils.accuracy(output.as_in_context(mx.cpu(0)), label)
        cur_time = datetime.datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = " Time : %02dhour %02dmin %02dsec" % (h, m, s)
        if valid_data is not None:
            valid_acc = utils.evaluate_accuracy(valid_data, net, contexts)
            epoch_str = ("epoch:%d ,loss:%f ,train_acc:%f , valid_acc:%f"
                         % (epoch, train_loss / len(train_data),
                            train_acc / len(train_data), valid_acc))
        else:
            epoch_str = ("epoch:%d ,loss:%f ,train_acc:%f"
                         % (epoch, train_loss / len(train_data),
                            train_acc / len(train_data)))
        prev_time = cur_time
        print(time_str + epoch_str + ' learning_rate:' + str(lr))
```
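For context on the second question: my understanding is that gluon.Trainer holds the per-parameter optimizer state (the SGD momentum buffers here), so rebuilding it on every batch would throw that state away. Below is a rough, untested sketch of the restructuring I have in mind, reusing the same helpers as above (criterion, gutils, etc.) and dropping the timing/accuracy bookkeeping for brevity. Is this the right fix, or is something else going on?

```python
# Sketch (untested): build the Trainer once so the SGD momentum buffers
# persist across batches, instead of rebuilding it inside every batch.
def multi_gpu_train_batch(net, data, label, batch_size, trainer, contexts):
    data_list = gutils.split_and_load(data, contexts)
    label_list = gutils.split_and_load(label, contexts)
    with autograd.record():
        outputs = [net(X) for X in data_list]
        losses = [criterion(y_hat, y) for y_hat, y in zip(outputs, label_list)]
    for l in losses:
        l.backward()
    trainer.step(batch_size)
    # gather results on the CPU so they can be concatenated regardless of device
    return (nd.concat(*[o.as_in_context(mx.cpu(0)) for o in outputs], dim=0),
            nd.concat(*[l.as_in_context(mx.cpu(0)) for l in losses], dim=0))


def multi_gpu_train(net, train_data, valid_data, num_epochs, batch_size,
                    lr, momentum, weight_decay, contexts, lr_period, lr_decay):
    # Trainer created once; learning-rate decay goes through set_learning_rate
    # rather than passing a new lr into the batch function every time.
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'momentum': momentum,
                             'wd': weight_decay})
    i = 0
    for epoch in range(num_epochs):
        if epoch > 0 and i < len(lr_period) and epoch == lr_period[i]:
            trainer.set_learning_rate(trainer.learning_rate * lr_decay)
            i += 1
        for data, label in train_data:
            outputs, losses = multi_gpu_train_batch(
                net, data, label, batch_size, trainer, contexts)
            nd.waitall()
```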
I get around a 6% boost in accuracy on CIFAR-10 by using a very simple train function:
```python
def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period, lr_decay):
    # Trainer is created once, before the epoch loop
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
    prev_time = datetime.datetime.now()
    for epoch in range(num_epochs):
        if epoch > 0 and epoch % lr_period == 0:
            trainer.set_learning_rate(trainer.learning_rate * lr_decay)
        for X, y in train_data:
            y = y.astype('float32').as_in_context(ctx)
            with autograd.record():
                y_hat = net(X.as_in_context(ctx))
                l = loss(y_hat, y)  # loss and batch_size come from the enclosing scope
            l.backward()
            trainer.step(batch_size)
```
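For completeness, this is roughly how I wire it up. The exact network, batch size, and hyperparameters below are only illustrative, not what produced the 6% number; the point is that train() reads loss and batch_size from the enclosing scope, so they have to exist before you call it:

```python
# Illustrative setup only; loss and batch_size must be defined at module level
# because train() above refers to them directly.
import mxnet as mx
from mxnet import gluon, init

ctx = mx.gpu(0)                                # or mx.cpu()
batch_size = 128
loss = gluon.loss.SoftmaxCrossEntropyLoss()

transform = gluon.data.vision.transforms.ToTensor()
train_data = gluon.data.DataLoader(
    gluon.data.vision.CIFAR10(train=True).transform_first(transform),
    batch_size=batch_size, shuffle=True)
valid_data = gluon.data.DataLoader(
    gluon.data.vision.CIFAR10(train=False).transform_first(transform),
    batch_size=batch_size, shuffle=False)

net = gluon.model_zoo.vision.resnet18_v2(classes=10)
net.initialize(init=init.Xavier(), ctx=ctx)

train(net, train_data, valid_data, num_epochs=100,
      lr=0.1, wd=5e-4, ctx=ctx, lr_period=50, lr_decay=0.1)
```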