Hi everyone, I am trying out `gluon.contrib.estimator`, but I have run into a problem: training does not seem to make any progress. Here is my code:
import os
def train_with_estimator(train_df, val_df, img_dir, batch_size, epoches, lr=0.01,
                         ctx=mx.cpu(), use_lr_scheduler=False):
    """Train a SteelUnet segmentation model with ``gluon.contrib.estimator``.

    Args:
        train_df: Training annotations, passed to ``SteelDataset``.
        val_df: Validation annotations, passed to ``SteelDataset``.
        img_dir: Image directory, passed to ``SteelDataset``.
        batch_size: Mini-batch size for both data loaders.
        epoches: Maximum number of training epochs.
        lr: Base SGD learning rate.
        ctx: MXNet context used for initialization and training.
        use_lr_scheduler: If True, attach a ``FactorScheduler`` (factor 0.7
            every 20000 updates, 10 warm-up steps starting at 2e-5) to the
            optimizer. Defaults to False, i.e. a constant learning rate.

    Returns:
        The fitted ``Estimator``.
    """
    # Data pipelines; only the training split is shuffled.
    steel_dataset = SteelDataset(train_df, img_dir)
    steel_data = data.DataLoader(steel_dataset, batch_size=batch_size,
                                 num_workers=4, shuffle=True)
    val_steel_dataset = SteelDataset(val_df, img_dir)
    val_steel_data = data.DataLoader(val_steel_dataset, batch_size=batch_size,
                                     num_workers=4, shuffle=False)

    celoss = loss.SigmoidBCELoss()
    diceloss = dice_loss(weight=1.0, batch_axis=0)
    unet = SteelUnet(ctx=ctx)
    dc_metric = mx.metric.CustomMetric(feval=dice_metric, name='dice_coefficent')
    ce_metric = mx.metric.CustomMetric(feval=ce_loss_metric, name='ce_loss')

    # collect_params() selects parameters by regular expression on their names.
    unet.collect_params(".*weight").initialize(mx.initializer.Normal(0.02), ctx=ctx)
    unet.collect_params(".*bias").initialize(mx.initializer.Zero(), ctx=ctx)
    unet.collect_params(".*mean").initialize(mx.initializer.Zero(), ctx=ctx)
    # "|" is regex alternation. The previous selector ".*gamma || .*beta" had
    # an empty alternative between the two bars, which matches every name, so
    # it selected the whole network instead of just gamma/beta.
    unet.collect_params(".*gamma|.*beta|.*running_var").initialize(ctx=ctx)

    optimizer_params = {'learning_rate': lr, 'wd': 1e-5}
    if use_lr_scheduler:
        optimizer_params['lr_scheduler'] = lrs.FactorScheduler(
            step=20000, base_lr=lr, factor=0.7,
            warmup_steps=10, warmup_begin_lr=0.00002)
    trainer = Trainer(unet.collect_params(), 'sgd', optimizer_params)

    os.makedirs('model', exist_ok=True)
    unet.hybridize()

    # NOTE(review): confirm that the installed Estimator applies *every* loss
    # in this list. If it only uses the first entry (or expects a single
    # Loss), the Dice term is silently dropped; wrapping BCE + Dice in one
    # Loss subclass removes the ambiguity.
    # The initializer only acts on parameters not initialized above.
    model = Estimator(net=unet, trainer=trainer, loss=[celoss, diceloss],
                      metrics=[dc_metric, ce_metric],
                      initializer=mx.init.Normal(0.02), context=[ctx])

    # With DataLoader's default last_batch='keep', len(steel_data) counts the
    # trailing partial batch, so this is the real number of updates in
    # `epoches` epochs (len(dataset) // batch_size under-counted it).
    stop_handler = mx.gluon.contrib.estimator.StoppingHandler(
        max_epoch=epoches, max_batch=epoches * len(steel_data))
    # NOTE(review): the same metric objects are passed as both train and
    # validation metrics, so both lines report whichever data last updated
    # them — confirm against the Estimator's own validation metrics.
    log_handler = mx.gluon.contrib.estimator.LoggingHandler(
        file_name="unet.log", file_location='.',
        train_metrics=[dc_metric, ce_metric],
        val_metrics=[dc_metric, ce_metric],
        verbose=2)
    # A Dice coefficient improves upward; state it explicitly instead of
    # letting mode='auto' guess the direction from the metric name.
    early_stop = mx.gluon.contrib.estimator.EarlyStoppingHandler(
        monitor=dc_metric, patience=5, mode='max')
    ckpt_handler = mx.gluon.contrib.estimator.CheckpointHandler(
        model_dir="model/",
        model_prefix="resnet18",
        monitor=dc_metric,
        save_best=True,
        mode='max',
        verbose=1)

    model.fit(
        train_data=steel_data,
        val_data=val_steel_data,
        epochs=epoches,
        event_handlers=[stop_handler, log_handler, early_stop, ckpt_handler],
    )
    return model
The training log is as follows (the Dice coefficient stays around 0.013 and the BCE loss stays near 0.69):
[Epoch 0] Begin, current learning rate: 0.0100
[Epoch 0][Batch 0][Samples 16] time/batch: 6.799s train dice_coefficent: 0.0170, train ce_loss: 0.7001
[Epoch 0][Batch 1][Samples 32] time/batch: 0.322s train dice_coefficent: 0.0130, train ce_loss: 0.6984
[Epoch 0][Batch 2][Samples 48] time/batch: 0.360s train dice_coefficent: 0.0119, train ce_loss: 0.6980
[Epoch 0][Batch 3][Samples 64] time/batch: 0.370s train dice_coefficent: 0.0130, train ce_loss: 0.6983
[Epoch 0][Batch 4][Samples 80] time/batch: 0.277s train dice_coefficent: 0.0136, train ce_loss: 0.6986
[Epoch 0][Batch 5][Samples 96] time/batch: 0.360s train dice_coefficent: 0.0134, train ce_loss: 0.6985
[Epoch 0][Batch 6][Samples 112] time/batch: 0.400s train dice_coefficent: 0.0128, train ce_loss: 0.6983
[Epoch 0][Batch 7][Samples 128] time/batch: 0.416s train dice_coefficent: 0.0136, train ce_loss: 0.6987
[Epoch 0][Batch 8][Samples 144] time/batch: 0.368s train dice_coefficent: 0.0143, train ce_loss: 0.6990
[Epoch 0][Batch 9][Samples 160] time/batch: 0.337s train dice_coefficent: 0.0149, train ce_loss: 0.6993
[Epoch 0][Batch 10][Samples 176] time/batch: 0.380s train dice_coefficent: 0.0140, train ce_loss: 0.6990
[Epoch 0][Batch 11][Samples 192] time/batch: 0.280s train dice_coefficent: 0.0136, train ce_loss: 0.6988
[Epoch 0][Batch 12][Samples 208] time/batch: 0.367s train dice_coefficent: 0.0137, train ce_loss: 0.6988
[Epoch 0][Batch 13][Samples 224] time/batch: 0.393s train dice_coefficent: 0.0135, train ce_loss: 0.6988
[Epoch 0][Batch 14][Samples 240] time/batch: 0.351s train dice_coefficent: 0.0138, train ce_loss: 0.6989
[Epoch 0][Batch 15][Samples 256] time/batch: 0.314s train dice_coefficent: 0.0142, train ce_loss: 0.6991
[Epoch 0][Batch 16][Samples 272] time/batch: 0.342s train dice_coefficent: 0.0135, train ce_loss: 0.6988
[Epoch 0][Batch 17][Samples 288] time/batch: 0.371s train dice_coefficent: 0.0136, train ce_loss: 0.6989
[Epoch 0][Batch 18][Samples 304] time/batch: 0.390s train dice_coefficent: 0.0131, train ce_loss: 0.6987
[Epoch 0][Batch 19][Samples 320] time/batch: 0.354s train dice_coefficent: 0.0128, train ce_loss: 0.6985
[Epoch 0][Batch 20][Samples 336] time/batch: 0.362s train dice_coefficent: 0.0125, train ce_loss: 0.6984
[Epoch 0][Batch 21][Samples 352] time/batch: 0.378s train dice_coefficent: 0.0124, train ce_loss: 0.6984
[Epoch 0][Batch 22][Samples 368] time/batch: 0.304s train dice_coefficent: 0.0121, train ce_loss: 0.6982
[Epoch 0][Batch 23][Samples 384] time/batch: 0.382s train dice_coefficent: 0.0123, train ce_loss: 0.6983
[Epoch 0][Batch 24][Samples 400] time/batch: 0.364s train dice_coefficent: 0.0127, train ce_loss: 0.6985
[Epoch 0][Batch 25][Samples 416] time/batch: 0.324s train dice_coefficent: 0.0127, train ce_loss: 0.6985
[Epoch 0][Batch 26][Samples 432] time/batch: 0.388s train dice_coefficent: 0.0138, train ce_loss: 0.6990
[Epoch 0][Batch 27][Samples 448] time/batch: 0.370s train dice_coefficent: 0.0135, train ce_loss: 0.6989
[Epoch 0][Batch 28][Samples 464] time/batch: 0.391s train dice_coefficent: 0.0134, train ce_loss: 0.6989
[Epoch 0][Batch 29][Samples 480] time/batch: 0.335s train dice_coefficent: 0.0132, train ce_loss: 0.6988
[Epoch 0][Batch 30][Samples 496] time/batch: 0.388s train dice_coefficent: 0.0135, train ce_loss: 0.6989
[Epoch 0][Batch 31][Samples 512] time/batch: 0.336s train dice_coefficent: 0.0132, train ce_loss: 0.6988
[Epoch 0][Batch 32][Samples 528] time/batch: 0.329s train dice_coefficent: 0.0134, train ce_loss: 0.6989
[Epoch 0][Batch 33][Samples 544] time/batch: 0.385s train dice_coefficent: 0.0132, train ce_loss: 0.6989
[Epoch 0][Batch 34][Samples 560] time/batch: 0.389s train dice_coefficent: 0.0131, train ce_loss: 0.6988
The model is a U-Net with a ResNet-18 backbone, and dice_loss is defined as follows:
class dice_loss(mx.gluon.loss.Loss):
    """Soft Dice loss computed on sigmoid-activated logits.

    For every sample, ``1 - (2 * sum(p * y) + smooth) / (sum(p + y) + smooth)``
    is computed, where ``p = sigmoid(preds)`` and ``y = labels``. Both are
    flattened to ``(batch, -1)`` first.

    Parameters
    ----------
    weight : float or None
        Global scalar multiplied into the loss; ``None`` means no scaling.
    batch_axis : int
        Axis that represents the mini-batch. ``F.flatten`` keeps axis 0 and
        collapses the rest, so this loss assumes ``batch_axis=0``.
    smooth : float
        Additive smoothing in numerator and denominator; keeps the ratio
        defined when a sample's prediction and label are both empty.
    """

    def __init__(self, weight=None, batch_axis=0, smooth=1e-5, **kwargs):
        # Must be spelled ``__init__``: a misspelled ``__init`` is a plain
        # (name-mangled) method that Python never calls on construction.
        super(dice_loss, self).__init__(weight, batch_axis, **kwargs)
        self._smooth = smooth

    def hybrid_forward(self, F, preds, labels, **kwargs):
        """Return the per-sample Dice loss, shape ``(batch,)``.

        One value per sample follows the Gluon ``Loss`` convention:
        ``Trainer.step(batch_size)`` already divides gradients by the batch
        size, so averaging over the batch here as well would shrink the Dice
        gradient by an extra factor of ``batch_size``.
        """
        probs = F.flatten(F.sigmoid(preds))
        labels = F.flatten(labels)
        intersections = F.sum(probs * labels, axis=1)
        unions = F.sum(probs + labels, axis=1)
        dice = 1. - (2. * intersections + self._smooth) / (unions + self._smooth)
        # Honour the constructor's ``weight``; it was previously ignored.
        if self._weight is not None:
            dice = dice * self._weight
        return dice
What I have tried so far:
- changing the initializer to Xavier
- varying the learning rate from 0.001 to 0.1
- using only the BCE loss