Odd behavior of mx.sym.SoftmaxOutput


#1

Hi.
I found that mx.sym.SoftmaxOutput gives no error or warning when the label contains invalid (out-of-range) class indices.
Here is the code I wrote:

```python
import mxnet as mx

data = mx.sym.Variable('data')
label = mx.sym.Variable('label')

focus = mx.sym.SoftmaxOutput(data=data, label=label)

# valid labels: class index 1 for each of the 4 samples (8 classes)
args = {'data': mx.nd.random.uniform(0, 1, (4, 8)), 'label': mx.nd.ones((4,))}
c_exec = focus.simple_bind(ctx=mx.cpu(), data=(4, 8), label=(4,), grad_req='write')

c_exec.copy_params_from(arg_params=args)

c_exec.forward(is_train=True)  # training mode, since backward() follows
c_exec.backward()
print(c_exec.grad_arrays)
```

Next, I give the model invalid label indices, and it still gives no error feedback. It also returns computed gradients for the model arguments. (Here I did not enable label ignoring, i.e. I did not set use_ignore=True or ignore_label.)

```python
args['label'] = mx.nd.ones((4,)) * 1234  # invalid label index

c_exec = focus.simple_bind(ctx=mx.cpu(), data=(4, 8), label=(4,), grad_req='write')

c_exec.copy_params_from(arg_params=args)

c_exec.forward(is_train=True)
c_exec.backward()
print(c_exec.grad_arrays)
```

I don’t understand how the model can compute gradients from out-of-range label indices. What is going on here?


#2

The gradient of the cross-entropy loss with respect to the softmax input (the logits, i.e. data here) is:
gradient = output − one_hot(label)
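For reference, the formula comes from differentiating softmax followed by cross-entropy. Writing z for one row of data (the logits), p = softmax(z), and y for the one-hot target over K classes:

$$
p_j = \frac{e^{z_j}}{\sum_{k=1}^{K} e^{z_k}}, \qquad
L = -\sum_{k=1}^{K} y_k \log p_k, \qquad
\frac{\partial \log p_k}{\partial z_j} = \delta_{jk} - p_j
$$

$$
\frac{\partial L}{\partial z_j}
= -\sum_{k=1}^{K} y_k \,(\delta_{jk} - p_j)
= p_j \sum_{k=1}^{K} y_k - y_j
= p_j - y_j \quad \text{if } \sum_{k=1}^{K} y_k = 1
$$

The last step assumes y contains exactly one 1. With an out-of-range label, y is all zeros, so the loss itself is 0 and its exact derivative is 0; it is only the simplified form p − y, if an implementation hard-codes it, that returns the softmax output.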
When label = 1234, none of the 8 classes matches the label, so the ground-truth (one-hot) value for each of the 8 elements of the prediction is 0.
By this formula the gradient should then simply be the output. Here is a check:

```python
args['label'] = (mx.nd.ones((4,)) * 1234).astype('int32')  # invalid label index
c_exec = focus.simple_bind(ctx=mx.cpu(), data=(4, 8), label=(4,), grad_req='write')

c_exec.forward(is_train=True, **args)  # kwargs are copied into the bound arrays
c_exec.backward()
print(c_exec.output_dict)
print(c_exec.grad_dict)
```

```
{'softmaxoutput0_output': 
[[0.15378468 0.08179748 0.11394583 0.08715329 0.15388192 0.16731301
  0.12735543 0.1147683 ]
 [0.14141993 0.1324424  0.11304702 0.11670923 0.1781388  0.08622951
  0.13038214 0.10163102]
 [0.11895253 0.14455761 0.14277396 0.12517804 0.07936542 0.12458669
  0.10526115 0.15932466]
 [0.13836394 0.09536436 0.09556355 0.10491768 0.13265201 0.15124239
  0.10978047 0.17211556]]
<NDArray 4x8 @cpu(0)>}
{'data': 
[[ 0.15378468  0.08179748  0.11394583  0.08715329 -0.8461181   0.16731301
   0.12735543  0.1147683 ]
 [ 0.14141993  0.1324424   0.11304702  0.11670923 -0.8218612   0.08622951
   0.13038214  0.10163102]
 [ 0.11895253  0.14455761  0.14277396  0.12517804 -0.92063457  0.12458669
   0.10526115  0.15932466]
 [ 0.13836394  0.09536436  0.09556355  0.10491768 -0.86734796  0.15124239
   0.10978047  0.17211556]]
```
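To see what the formula alone predicts, here is a small NumPy sketch (no MXNet involved) of the idealized gradient softmax(z) − one_hot(label). It assumes an out-of-range label simply matches no class, and the helper names softmax, one_hot and softmax_ce_grad are hypothetical. It is a model of the formula, not of MXNet's actual kernel.

```python
import numpy as np

def softmax(z):
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def one_hot(labels, num_classes):
    # Compare each label with 0..num_classes-1; a label outside that
    # range matches no column, so its row is all zeros.
    return (labels[:, None] == np.arange(num_classes)[None, :]).astype(np.float64)

def softmax_ce_grad(z, labels):
    # Simplified cross-entropy gradient w.r.t. the logits: p - y.
    return softmax(z) - one_hot(labels, z.shape[1])

rng = np.random.default_rng(0)
data = rng.uniform(0, 1, size=(4, 8))

valid = softmax_ce_grad(data, np.ones(4))          # label 1 for every row
invalid = softmax_ce_grad(data, np.full(4, 1234))  # out-of-range label

print(np.allclose(invalid, softmax(data)))  # True: gradient equals the output
print(valid.sum(axis=1))    # rows sum to ~0 when one class is subtracted
print(invalid.sum(axis=1))  # rows sum to ~1 when nothing is subtracted
```

The row sums are a quick diagnostic: a gradient row that sums to 1 instead of 0 means no class received the −1.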

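Note, though, that the MXNet gradient printed above is not exactly the output: in every row, column 4 equals output − 1 (e.g. 0.15388192 − 1 = −0.8461181), as if the label had been read as 4. So the out-of-range label is neither rejected nor cleanly ignored. I agree that when the label is out of the index range it shouldn’t be silently accepted; I am not aware of a use case where that would make sense. A warning, or even an error, would be a better user experience here.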
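Until the operator checks this itself, the labels can be validated on the host before calling forward. A minimal sketch, assuming the labels are available as a NumPy array (e.g. via .asnumpy() on the label NDArray); check_labels is a hypothetical helper, not an MXNet API. The optional ignore_label argument mimics SoftmaxOutput's use_ignore=True behavior by skipping labels equal to that value.

```python
import numpy as np

def check_labels(labels, num_classes, ignore_label=None):
    """Raise ValueError unless every label is an integer class index in [0, num_classes).

    Hypothetical helper: labels equal to ignore_label (if given) are skipped,
    similar to SoftmaxOutput with use_ignore=True.
    """
    labels = np.asarray(labels)
    mask = np.ones(labels.shape, dtype=bool)
    if ignore_label is not None:
        mask = labels != ignore_label
    checked = labels[mask]
    # Flag negatives, values >= num_classes, and non-integer values.
    bad = (checked < 0) | (checked >= num_classes) | (checked != np.floor(checked))
    if bad.any():
        raise ValueError(
            f"invalid class indices for {num_classes} classes: "
            f"{np.unique(checked[bad]).tolist()}"
        )

check_labels(np.ones(4), num_classes=8)  # valid labels: no exception
try:
    check_labels(np.full(4, 1234.0), num_classes=8)
except ValueError as err:
    print(err)  # reports the offending value(s)
```

Running the check on the host costs a device-to-host copy of the labels, so in a training loop it is cheapest to apply once to the dataset rather than on every batch.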


#3

Thank you for replying :smile: I also agree that an error would be better for users.