Gluon.contrib.estimator: loss stays unchanged

Hi everyone, I am trying gluon.contrib.estimator, but I have run into a problem:

import os
import mxnet as mx
from mxnet import lr_scheduler as lrs
from mxnet.gluon import Trainer, data, loss
from mxnet.gluon.contrib.estimator import Estimator

# SteelDataset, SteelUnet, dice_metric and ce_loss_metric are defined elsewhere; dice_loss is shown below
def train_with_estimator(train_df, val_df, img_dir, batch_size, epoches, lr=0.01, ctx=mx.cpu()):
    # TODO: finish trainer etc., add ctx
    steel_dataset = SteelDataset(train_df, img_dir)
    steel_data = data.DataLoader(steel_dataset, batch_size=batch_size, num_workers=4, shuffle=True)
    
    val_steel_dataset = SteelDataset(val_df, img_dir)
    val_steel_data = data.DataLoader(val_steel_dataset, batch_size=batch_size, num_workers=4, shuffle=False)
    
    celoss = loss.SigmoidBCELoss()
    diceloss = dice_loss(weight=1.0, batch_axis=0)
    unet = SteelUnet(ctx=ctx)
    dc_metric = mx.metric.CustomMetric(feval=dice_metric, name='dice_coefficent')
    ce_metric = mx.metric.CustomMetric(feval=ce_loss_metric, name='ce_loss')
    unet.collect_params(".*weight").initialize(mx.initializer.Normal(0.02), ctx=ctx)
    unet.collect_params(".*bias").initialize(mx.initializer.Zero(), ctx=ctx)
    unet.collect_params(".*mean").initialize(mx.initializer.Zero(), ctx=ctx)
    unet.collect_params(".*gamma || .*beta").initialize(ctx=ctx)
    lr_sche = lrs.FactorScheduler(step=20000, base_lr=lr, factor=0.7,  warmup_steps=10, warmup_begin_lr=0.00002)
    trainer = Trainer(unet.collect_params(), 'sgd', 
                        {'learning_rate': lr,
                         'wd':1e-5
#                          'lr_scheduler': lr_sche
                        })
    
    if not os.path.exists('model'):
        os.mkdir('model')
    unet.hybridize()
    model = Estimator(net=unet, trainer=trainer, loss=[celoss, diceloss], metrics=[dc_metric,ce_metric], initializer=mx.init.Normal(0.02), context=[ctx])
    stop_handler = mx.gluon.contrib.estimator.StoppingHandler(max_epoch=epoches, max_batch=epoches * len(steel_dataset)//batch_size)
    log_handler = mx.gluon.contrib.estimator.LoggingHandler(file_name="unet.log", file_location='.', train_metrics=[dc_metric, ce_metric],\
                                                           val_metrics=[dc_metric, ce_metric],
                                                           verbose=2)
    early_stop = mx.gluon.contrib.estimator.EarlyStoppingHandler(monitor=dc_metric, patience=5)
    ckpt_handler = mx.gluon.contrib.estimator.CheckpointHandler(model_dir="model/", 
                                                                model_prefix="resnet18", 
                                                                monitor=dc_metric,
                                                                save_best=True,
                                                                verbose=1)
    model.fit(
        train_data=steel_data,
        val_data=val_steel_data,
        epochs=epoches,
        event_handlers=[stop_handler, log_handler, early_stop, ckpt_handler]
    )
    return model
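
The feval callbacks dice_metric and ce_loss_metric passed to mx.metric.CustomMetric above are not shown; roughly, they compute a soft dice coefficient and a sigmoid BCE value on the numpy arrays that CustomMetric hands to feval. A minimal sketch of that shape (the exact bodies here are only illustrative, assuming raw logits as predictions and binary masks as labels):

import numpy as np

def dice_metric(label, pred):
    # CustomMetric hands feval numpy arrays; pred here are raw logits from the network
    pred = 1.0 / (1.0 + np.exp(-pred))                      # sigmoid
    pred = pred.reshape(pred.shape[0], -1)
    label = label.reshape(label.shape[0], -1)
    inter = (pred * label).sum(axis=1)
    union = (pred + label).sum(axis=1)
    return float(np.mean((2.0 * inter + 1e-5) / (union + 1e-5)))

def ce_loss_metric(label, pred):
    # element-wise sigmoid binary cross-entropy, averaged over the batch
    pred = 1.0 / (1.0 + np.exp(-pred))
    eps = 1e-12
    ce = -(label * np.log(pred + eps) + (1 - label) * np.log(1 - pred + eps))
    return float(ce.mean())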

The training log is as follows:

[Epoch 0] Begin, current learning rate: 0.0100
[Epoch 0][Batch 0][Samples 16] time/batch: 6.799s train dice_coefficent: 0.0170, train ce_loss: 0.7001
[Epoch 0][Batch 1][Samples 32] time/batch: 0.322s train dice_coefficent: 0.0130, train ce_loss: 0.6984
[Epoch 0][Batch 2][Samples 48] time/batch: 0.360s train dice_coefficent: 0.0119, train ce_loss: 0.6980
[Epoch 0][Batch 3][Samples 64] time/batch: 0.370s train dice_coefficent: 0.0130, train ce_loss: 0.6983
[Epoch 0][Batch 4][Samples 80] time/batch: 0.277s train dice_coefficent: 0.0136, train ce_loss: 0.6986
[Epoch 0][Batch 5][Samples 96] time/batch: 0.360s train dice_coefficent: 0.0134, train ce_loss: 0.6985
[Epoch 0][Batch 6][Samples 112] time/batch: 0.400s train dice_coefficent: 0.0128, train ce_loss: 0.6983
[Epoch 0][Batch 7][Samples 128] time/batch: 0.416s train dice_coefficent: 0.0136, train ce_loss: 0.6987
[Epoch 0][Batch 8][Samples 144] time/batch: 0.368s train dice_coefficent: 0.0143, train ce_loss: 0.6990
[Epoch 0][Batch 9][Samples 160] time/batch: 0.337s train dice_coefficent: 0.0149, train ce_loss: 0.6993
[Epoch 0][Batch 10][Samples 176] time/batch: 0.380s train dice_coefficent: 0.0140, train ce_loss: 0.6990
[Epoch 0][Batch 11][Samples 192] time/batch: 0.280s train dice_coefficent: 0.0136, train ce_loss: 0.6988
[Epoch 0][Batch 12][Samples 208] time/batch: 0.367s train dice_coefficent: 0.0137, train ce_loss: 0.6988
[Epoch 0][Batch 13][Samples 224] time/batch: 0.393s train dice_coefficent: 0.0135, train ce_loss: 0.6988
[Epoch 0][Batch 14][Samples 240] time/batch: 0.351s train dice_coefficent: 0.0138, train ce_loss: 0.6989
[Epoch 0][Batch 15][Samples 256] time/batch: 0.314s train dice_coefficent: 0.0142, train ce_loss: 0.6991
[Epoch 0][Batch 16][Samples 272] time/batch: 0.342s train dice_coefficent: 0.0135, train ce_loss: 0.6988
[Epoch 0][Batch 17][Samples 288] time/batch: 0.371s train dice_coefficent: 0.0136, train ce_loss: 0.6989
[Epoch 0][Batch 18][Samples 304] time/batch: 0.390s train dice_coefficent: 0.0131, train ce_loss: 0.6987
[Epoch 0][Batch 19][Samples 320] time/batch: 0.354s train dice_coefficent: 0.0128, train ce_loss: 0.6985
[Epoch 0][Batch 20][Samples 336] time/batch: 0.362s train dice_coefficent: 0.0125, train ce_loss: 0.6984
[Epoch 0][Batch 21][Samples 352] time/batch: 0.378s train dice_coefficent: 0.0124, train ce_loss: 0.6984
[Epoch 0][Batch 22][Samples 368] time/batch: 0.304s train dice_coefficent: 0.0121, train ce_loss: 0.6982
[Epoch 0][Batch 23][Samples 384] time/batch: 0.382s train dice_coefficent: 0.0123, train ce_loss: 0.6983
[Epoch 0][Batch 24][Samples 400] time/batch: 0.364s train dice_coefficent: 0.0127, train ce_loss: 0.6985
[Epoch 0][Batch 25][Samples 416] time/batch: 0.324s train dice_coefficent: 0.0127, train ce_loss: 0.6985
[Epoch 0][Batch 26][Samples 432] time/batch: 0.388s train dice_coefficent: 0.0138, train ce_loss: 0.6990
[Epoch 0][Batch 27][Samples 448] time/batch: 0.370s train dice_coefficent: 0.0135, train ce_loss: 0.6989
[Epoch 0][Batch 28][Samples 464] time/batch: 0.391s train dice_coefficent: 0.0134, train ce_loss: 0.6989
[Epoch 0][Batch 29][Samples 480] time/batch: 0.335s train dice_coefficent: 0.0132, train ce_loss: 0.6988
[Epoch 0][Batch 30][Samples 496] time/batch: 0.388s train dice_coefficent: 0.0135, train ce_loss: 0.6989
[Epoch 0][Batch 31][Samples 512] time/batch: 0.336s train dice_coefficent: 0.0132, train ce_loss: 0.6988
[Epoch 0][Batch 32][Samples 528] time/batch: 0.329s train dice_coefficent: 0.0134, train ce_loss: 0.6989
[Epoch 0][Batch 33][Samples 544] time/batch: 0.385s train dice_coefficent: 0.0132, train ce_loss: 0.6989
[Epoch 0][Batch 34][Samples 560] time/batch: 0.389s train dice_coefficent: 0.0131, train ce_loss: 0.6988

The model is a U-Net based on resnet18, and dice_loss is defined as follows:

class dice_loss(mx.gluon.loss.Loss):
    def __init__(self, weight, batch_axis, **kwargs):
        super(dice_loss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, preds, labels, **kwargs):
        preds = F.sigmoid(preds)        # logits -> probabilities
        preds = preds.flatten()         # (batch, H*W)
        labels = labels.flatten()

        intersections = F.sum(preds * labels, axis=1)
        unions = F.sum(preds + labels, axis=1)
        dice = 1. - (2. * intersections + 1e-5) / (unions + 1e-5)   # per-sample soft dice loss
        return F.mean(dice, axis=self._batch_axis)
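
A quick imperative sanity check on random data (just to confirm the output range; values near 1.0 mean almost no overlap, near 0.0 mean perfect overlap):

logits = mx.nd.random.normal(shape=(2, 1, 64, 64))          # raw network outputs
masks = mx.nd.random.uniform(shape=(2, 1, 64, 64)) > 0.5    # binary ground-truth masks
print(dice_loss(weight=1.0, batch_axis=0)(logits, masks))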

What I have done:

  • changed the initializer to Xavier
  • adjusted the learning rate from 0.001 to 0.1
  • used only BCELoss

Hi,
I think the Estimator does not support multiple losses for now; it can only run backward on one loss to update the parameters. You might want to define a single composite loss, weight the two terms, and pass that to the Estimator.
For example, with loss weights [0.5, 0.5] your final loss could have a hybrid_forward like this:

def hybrid_forward(self, F, preds, labels, **kwargs):
    return loss_weight[0] * bce_loss(preds, labels) + loss_weight[1] * dice_loss(preds, labels)
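
For reference, a minimal sketch of such a combined loss (the class name and the 0.5/0.5 weights are just illustrative; it folds the dice computation from your dice_loss into one Loss subclass next to SigmoidBCELoss):

from mxnet import gluon

class bce_dice_loss(gluon.loss.Loss):
    def __init__(self, loss_weight=(0.5, 0.5), weight=None, batch_axis=0, **kwargs):
        super(bce_dice_loss, self).__init__(weight, batch_axis, **kwargs)
        self._loss_weight = loss_weight
        self._bce = gluon.loss.SigmoidBCELoss()

    def hybrid_forward(self, F, preds, labels, **kwargs):
        bce = self._bce(preds, labels)                           # per-sample sigmoid BCE
        probs = F.sigmoid(preds).flatten()                       # (batch, H*W) probabilities
        labels_flat = labels.flatten()
        inter = F.sum(probs * labels_flat, axis=1)
        union = F.sum(probs + labels_flat, axis=1)
        dice = 1. - (2. * inter + 1e-5) / (union + 1e-5)         # per-sample soft dice loss
        return self._loss_weight[0] * bce + self._loss_weight[1] * dice

Then pass loss=bce_dice_loss() (a single loss) to the Estimator instead of the list.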

Also, you mentioned that using only BCE loss does not work either, so there may be some other problem. Could you try running your model the imperative Gluon way (writing the training loop manually)?
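
Something like this minimal loop, reusing the objects from your function (unet, celoss, trainer, steel_data and the ctx you pass in), would tell you whether the problem is in the Estimator setup or in the model/data itself:

from mxnet import autograd

def train_imperative(unet, trainer, celoss, steel_data, epoches, ctx):
    for epoch in range(epoches):
        total_loss, n_batches = 0.0, 0
        for imgs, masks in steel_data:
            imgs = imgs.as_in_context(ctx)
            masks = masks.as_in_context(ctx)
            with autograd.record():
                preds = unet(imgs)
                l = celoss(preds, masks)
            l.backward()
            trainer.step(imgs.shape[0])         # normalize by batch size
            total_loss += l.mean().asscalar()
            n_batches += 1
        print('[Epoch %d] avg BCE loss: %.4f' % (epoch, total_loss / n_batches))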

Thanks for your answer.