Say I have a network defined with mxnet.symbol.
My loss depends on the batch size of the feature map:
    x = get_network()                      # x is a symbol
    batch_size = get_batch_size()          # THE QUESTION
    diag_mask = mx.symbol.eye(batch_size)
    loss = get_loss(x, diag_mask)
My dataset size is not divisible by the batch size, so when the program reaches the last batch of the dataset, it fails with a dimension mismatch error.
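To make the failure concrete, here is a NumPy-only sketch (no MXNet involved). The dataset of 10 samples, the batch size of 4 and the `batches` helper are hypothetical. It splits the data into consecutive batches and checks whether a mask built once from the fixed batch size, as `mx.symbol.eye(batch_size)` is, still matches each batch.

```python
import numpy as np

def batches(data, batch_size):
    """Hypothetical helper: yield consecutive batches; the last one may be smaller."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

# Hypothetical sizes: 10 samples do not divide evenly into batches of 4.
data = np.arange(10 * 3, dtype=np.float32).reshape(10, 3)
batch_size = 4
fixed_mask = np.eye(batch_size)  # built once, like mx.symbol.eye(batch_size)

batch_sizes = []
mismatches = []
for batch in batches(data, batch_size):
    batch_sizes.append(batch.shape[0])
    # A mask sized from the fixed batch_size only fits the full batches.
    mismatches.append(fixed_mask.shape[0] != batch.shape[0])

print(batch_sizes)  # [4, 4, 2]
print(mismatches)   # [False, False, True]
```

Only the final, shorter batch disagrees with the mask, which is where the dimension mismatch shows up.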
How can I get the batch size at run time, like tf.shape(x) in TensorFlow?
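What tf.shape(x) provides is the batch dimension read from the data at run time, rather than a number fixed when the graph is built. The NumPy sketch below illustrates only that idea; it is not an MXNet solution. The `diag_mask_for` helper and the toy Gram-matrix loss are hypothetical stand-ins for `get_loss`.

```python
import numpy as np

def diag_mask_for(x):
    """Hypothetical helper: size the diagonal mask from x's actual leading
    dimension, the NumPy analogue of using tf.shape(x)[0]."""
    n = x.shape[0]
    return np.eye(n, dtype=x.dtype)

def toy_loss(x, mask):
    # Hypothetical loss: sum of squared off-diagonal entries of x @ x.T.
    gram = x @ x.T
    return float(np.sum((gram * (1.0 - mask)) ** 2))

full = np.ones((4, 3), dtype=np.float32)  # a full batch
last = np.ones((2, 3), dtype=np.float32)  # a shorter final batch
for batch in (full, last):
    mask = diag_mask_for(batch)
    print(mask.shape, toy_loss(batch, mask))
```

Because the mask is derived from each batch, both the 4-sample and the 2-sample batch work. Note that `mx.symbol.eye` takes its size `N` as a plain integer, so a shape computed at run time cannot be passed to it directly.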