Hi.
I found that passing labels with out-of-range class indices to mx.sym.SoftmaxOutput produces no error or warning.
Here is the code I wrote.
import mxnet as mx
data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
focus = mx.sym.SoftmaxOutput(data=data, label=label)
# give valid label index value
args = {'data': mx.nd.random.uniform(0, 1, (4, 8)), 'label': mx.nd.ones((4,))}
c_exec = focus.simple_bind(ctx=mx.cpu(), data=(4, 8), label=(4,), grad_req='write')
c_exec.copy_params_from(arg_params=args)
c_exec.forward()
c_exec.backward()
print(c_exec.grad_arrays)
Next, I pass invalid label indices to the model. It raises no error and still returns computed gradients for the arguments. (I did not set use_ignore or ignore_label here.)
args['label']=mx.nd.ones((4,)) * 1234 # invalid label index
c_exec = focus.simple_bind(ctx=mx.cpu(), data=(4, 8), label=(4,), grad_req='write')
c_exec.copy_params_from(arg_params=args)
c_exec.forward()
c_exec.backward()
print(c_exec.grad_arrays)
I don’t understand how the model can compute gradients from invalid label indices. What is going on here?
The gradient of the cross-entropy loss with respect to the softmax input (the logits passed in as data) is:
gradient = output − one_hot(label)
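As a quick sanity check of this formula, here is a pure-Python sketch (no MXNet needed) that compares the analytic gradient softmax(z) − one_hot(y) of the per-sample loss −log softmax(z)[y] against central finite differences for a valid label. The logits and helper names are made up for illustration; this models the textbook loss, not MXNet's internal kernel.

```python
import math

def softmax(z):
    # Subtract the max before exponentiating for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def one_hot(label, num_classes):
    # 1.0 at the label position, 0.0 everywhere else.
    return [1.0 if k == label else 0.0 for k in range(num_classes)]

def cross_entropy(z, label):
    # Per-sample loss: -log of the probability assigned to the true class.
    return -math.log(softmax(z)[label])

def analytic_grad(z, label):
    # dL/dz = softmax(z) - one_hot(label)
    p = softmax(z)
    y = one_hot(label, len(z))
    return [pk - yk for pk, yk in zip(p, y)]

def numeric_grad(z, label, eps=1e-6):
    # Central finite differences, used as an independent check.
    grad = []
    for k in range(len(z)):
        zp = list(z)
        zp[k] += eps
        zm = list(z)
        zm[k] -= eps
        grad.append((cross_entropy(zp, label) - cross_entropy(zm, label)) / (2 * eps))
    return grad

logits = [0.3, -1.2, 0.8, 0.1, 0.5, -0.4, 0.9, 0.0]  # made-up values, 8 classes
label = 1
g_a = analytic_grad(logits, label)
g_n = numeric_grad(logits, label)
print(max(abs(a - b) for a, b in zip(g_a, g_n)))  # close to 0
```

With a valid label, the entries of each gradient row also sum to roughly 0, because the softmax row sums to 1 and exactly one 1 is subtracted.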
When label = 1234, no position among the 8 classes matches that index, so every element of the one-hot target is 0 and the gradient reduces to the softmax output itself. You can check this by printing the outputs and the gradients:
args['label']=(mx.nd.ones((4,))*1234).astype('int32') # invalid label index
c_exec = focus.simple_bind(ctx=mx.cpu(), data=(4, 8), label=(4,), grad_req='write')
c_exec.forward(**args)
c_exec.backward()
print(c_exec.output_dict)
print(c_exec.grad_dict)
{'softmaxoutput0_output':
[[0.15378468 0.08179748 0.11394583 0.08715329 0.15388192 0.16731301
0.12735543 0.1147683 ]
[0.14141993 0.1324424 0.11304702 0.11670923 0.1781388 0.08622951
0.13038214 0.10163102]
[0.11895253 0.14455761 0.14277396 0.12517804 0.07936542 0.12458669
0.10526115 0.15932466]
[0.13836394 0.09536436 0.09556355 0.10491768 0.13265201 0.15124239
0.10978047 0.17211556]]
<NDArray 4x8 @cpu(0)>}
{'data':
[[ 0.15378468 0.08179748 0.11394583 0.08715329 -0.8461181 0.16731301
0.12735543 0.1147683 ]
[ 0.14141993 0.1324424 0.11304702 0.11670923 -0.8218612 0.08622951
0.13038214 0.10163102]
[ 0.11895253 0.14455761 0.14277396 0.12517804 -0.92063457 0.12458669
0.10526115 0.15932466]
[ 0.13836394 0.09536436 0.09556355 0.10491768 -0.86734796 0.15124239
0.10978047 0.17211556]]
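The same textbook formula applied to an out-of-range label can be sketched in plain Python (hypothetical helper names, made-up logits). Because no index in range(8) equals 1234, the subtracted one-hot row is empty and the gradient is exactly the softmax output. This also gives a handy diagnostic: with a valid label each gradient row sums to about 0, while with an out-of-range label it sums to about 1. Note that the gradient printed above does not follow this model exactly: in every row, the fifth column (index 4) equals the output minus 1, so the rows sum to about 0, and the installed MXNet version appears to have mapped 1234 onto some in-range class rather than dropping it. Either way, nothing reports the invalid label.

```python
import math

def softmax(z):
    # Subtract the max before exponentiating for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_ce_grad(z, label):
    # Textbook gradient: softmax(z) - one_hot(label).
    # If label is outside range(len(z)), no position matches, the target
    # row is all zeros and the gradient equals softmax(z). Nothing raises.
    p = softmax(z)
    return [pk - (1.0 if k == label else 0.0) for k, pk in enumerate(p)]

logits = [0.3, -1.2, 0.8, 0.1, 0.5, -0.4, 0.9, 0.0]  # made-up values, 8 classes
p = softmax(logits)
g_valid = softmax_ce_grad(logits, 4)
g_invalid = softmax_ce_grad(logits, 1234)

print(g_invalid == p)   # True: the gradient is just the output
print(sum(g_valid))     # close to 0 for a valid label
print(sum(g_invalid))   # close to 1 for an out-of-range label
```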
I agree that an out-of-range label should not be silently accepted. I am not aware of a use case where it would make sense, and I think a warning or even an error would be a better user experience here.
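Until the operator itself validates labels, a user-side guard is one option. The sketch below is a hypothetical helper, not part of MXNet: it expects the label values as a plain Python list (for example via label_nd.asnumpy().tolist()) and raises ValueError for any value that is not an integral index in [0, num_classes). The optional ignore_label argument is an assumption meant to mirror SoftmaxOutput's parameter of the same name, letting that one value through.

```python
def check_labels(labels, num_classes, ignore_label=None):
    """Hypothetical helper: raise ValueError if any label is not a valid class index."""
    bad = []
    for i, lab in enumerate(labels):
        # Let the ignore value through, as use_ignore/ignore_label would.
        if ignore_label is not None and lab == ignore_label:
            continue
        # MXNet labels are float32 by default, so also require an integral value.
        if lab != int(lab) or not 0 <= int(lab) < num_classes:
            bad.append((i, lab))
    if bad:
        raise ValueError(f"labels out of range [0, {num_classes}): {bad}")

check_labels([1.0, 1.0, 1.0, 1.0], num_classes=8)  # valid: returns None
try:
    check_labels([1234.0] * 4, num_classes=8)      # invalid: raises
except ValueError as err:
    print(err)
```

Running this on the host costs a device-to-host copy per batch, so it is better suited to validating a dataset once or to debugging than to every training step.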
Thank you for replying. I also agree that an error would be better for users.