I think I’m misunderstanding how metric `reset()`/`get()` are meant to be used, because as I progress through the batches of an epoch on my validation set, the reported accuracy starts high and then steadily decreases.
valid_acc.reset()
cumulative_valid_loss = mx.nd.zeros(1, ctx)
valid_samples = 0
tbar = tqdm(valid_dl)
for batch_idx, (data, label) in enumerate(tbar):
    data = data.as_in_context(ctx)
    label = label.as_in_context(ctx)
    output = self.net(data)
    loss = loss_fn(output, label)
    cumulative_valid_loss += loss.sum()
    valid_samples += data.shape[0]
    valid_loss = cumulative_valid_loss.asscalar() / valid_samples
    valid_acc.update(label, output)
    metric_name, metric_val = valid_acc.get()
    tbar.set_description(f'validation loss {valid_loss:.3f}, '
                         f'validation {metric_name}: {metric_val:.3f}')
Does anybody know why this is? All the examples I’ve seen only call metric.get()
at the end of an epoch. Is there a reason for that?
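For what it’s worth, my mental model of a cumulative metric is something like the simplified sketch below (hypothetical, not the actual MXNet implementation): `update()` accumulates counts across batches, and `get()` returns the running average since the last `reset()`, so the mid-epoch value reflects all batches so far rather than only the latest one. Is that right?

```python
# Hypothetical, simplified sketch of a cumulative accuracy metric
# (not the actual MXNet mx.metric code).
class RunningAccuracy:
    def __init__(self):
        self.reset()

    def reset(self):
        # Clear accumulated counts; get() starts fresh after this.
        self.num_correct = 0
        self.num_total = 0

    def update(self, labels, preds):
        # Compare predicted class indices against labels and accumulate.
        for label, pred in zip(labels, preds):
            self.num_correct += int(label == pred)
            self.num_total += 1

    def get(self):
        # Running average since the last reset(), not the latest batch's value.
        return 'accuracy', self.num_correct / max(self.num_total, 1)

metric = RunningAccuracy()
metric.update([1, 0], [1, 0])   # batch 1: 2/2 correct
print(metric.get())             # ('accuracy', 1.0)
metric.update([1, 1], [0, 0])   # batch 2: 0/2 correct
print(metric.get())             # ('accuracy', 0.5) -- running average, 2/4
```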