Custom loss function: shape mismatch error

I am migrating from TensorFlow, so I am not sure whether I am doing this properly. Consider the following code snippet:

    cross_entropy_cost = mx.symbol.SoftmaxOutput(data=self.logits, label=label, name=name + 'softmax')
    # softmax_label
    if self.distillation > 0:
        # Commented-out variant (fails with the same error):
        # distill_loss = self.distillation * rmse(
        #     mx.symbol.SoftmaxActivation(self.logits / self.temperature),
        #     mx.symbol.SoftmaxActivation(self.mentor_logits / self.temperature))
        distill_loss = mx.symbol.SoftmaxOutput(
            data=self.logits / self.temperature,
            label=mx.symbol.SoftmaxActivation(self.mentor_logits / self.temperature),
            name=name + '_distillation')
        loss = mx.symbol.Group([cross_entropy_cost, mx.symbol.MakeLoss(distill_loss)])
    else:
        loss = cross_entropy_cost
    return loss

If self.distillation is 0, the code works. Neither the commented-out version nor the active version works, and both produce the same error:

    2018-01-30 01:53:14 Traceback printing
    Traceback (most recent call last):
      File "start.py", line 144, in main
        trainer_student.train()
      File "----", line 14, in train
        return self.fitter.fit(args, sym, args_params=args_params, aux_params=aux_params)
      File "/home/---/fit_adapter.py", line 20, in fit
        f.fit(args, sym, data.get_rec_iter, arg_params=args_params, aux_params=aux_params)
      File "/home/---/fit.py", line 320, in fit
        monitor            = monitor)
      File "/home/---/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/module/base_module.py", line 496, in fit
        self.update_metric(eval_metric, data_batch.label)
      File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/module/module.py", line 748, in update_metric
        self._exec_group.update_metric(eval_metric, labels)
      File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/module/executor_group.py", line 588, in update_metric
        eval_metric.update_dict(labels_, preds)
      File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/metric.py", line 280, in update_dict
        metric.update_dict(labels, preds)
      File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/metric.py", line 108, in update_dict
        self.update(label, pred)
      File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/metric.py", line 388, in update
        check_label_shapes(labels, preds)
      File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/metric.py", line 41, in check_label_shapes
        "predictions {}".format(label_shape, pred_shape))
    ValueError: Shape of labels 1 does not match shape of predictions 2
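The traceback shows the failure happens in update_metric, not in the loss: with mx.symbol.Group the module produces two outputs per batch, while the data iterator supplies only one label array. A minimal pure-Python sketch of that kind of length check follows; check_lengths and the list entries are hypothetical stand-ins modelled on the error message, not MXNet's actual source.

```python
# Pure-Python sketch of the length check behind the ValueError above.
# check_lengths is a hypothetical stand-in modelled on the error message;
# it is not MXNet's actual implementation.


def check_lengths(labels, preds):
    """Raise if the number of label arrays differs from the number of outputs."""
    if len(labels) != len(preds):
        raise ValueError(
            "Shape of labels {} does not match shape of "
            "predictions {}".format(len(labels), len(preds))
        )


# One label array comes from the data iterator ...
labels = ["softmax_label"]
# ... but the Group exposes two outputs: the classifier and the distillation loss.
preds = ["softmax_output", "distillation_output"]

try:
    check_lengths(labels, preds)
except ValueError as err:
    print(err)  # Shape of labels 1 does not match shape of predictions 2
```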

It looks like you are trying to compute the cross entropy between two distributions. In that case there is no need to use SoftmaxOutput.

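Simply do:

    # Soft-target cross entropy per example: -sum_k p_mentor[k] * log q_student[k]
    p_mentor = mx.symbol.softmax(self.mentor_logits / self.temperature)
    ce = -mx.symbol.sum(p_mentor * mx.symbol.log_softmax(self.logits / self.temperature), axis=1)
    loss = mx.sym.make_loss(ce)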
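For reference, the per-example quantity here is the soft-target cross entropy H(p, q) = -sum_k p_k log q_k, with p = softmax(mentor_logits / T) and q = softmax(logits / T). Since make_loss treats its input as something to minimize, the leading minus sign matters: without it, training would push the student away from the teacher. A plain-Python sketch for a single example (the function names and numbers are illustrative, not MXNet API) shows the loss is smallest when the student matches the teacher:

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def soft_cross_entropy(student_logits, teacher_logits, temperature=1.0):
    """-sum_k p_teacher[k] * log q_student[k] for a single example."""
    log_q = log_softmax(student_logits, temperature)
    p = [math.exp(v) for v in log_softmax(teacher_logits, temperature)]
    return -sum(pk * lqk for pk, lqk in zip(p, log_q))


teacher = [2.0, 1.0, 0.1]
student = [0.5, 0.2, 1.5]
T = 2.0

print(soft_cross_entropy(student, teacher, T))
# When the student matches the teacher, the loss equals the teacher's entropy,
# which is the minimum possible value for this teacher distribution.
print(soft_cross_entropy(teacher, teacher, T))
```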

Same error. But I think it is arising in the metrics. Thanks, though.

Hi @ragavvenkatesan, do you have any update? I am also implementing knowledge-distillation-like training that uses mx.symbol.Group to get the intermediate activations, and I ran into the same error.
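One way around this is to let the accuracy metric see only the classification output and ignore the extra Group outputs. In MXNet 1.x the built-in metrics (for example mx.metric.Accuracy) accept output_names and label_names arguments that update_dict uses to pick outputs and labels by name; check the metric documentation for your version before relying on it. The sketch below shows the idea in plain Python; the output names are hypothetical (a SoftmaxOutput named "studentsoftmax" would expose an output called "studentsoftmax_output").

```python
# Plain-Python sketch: select outputs by name before computing accuracy.
# Output names below are hypothetical examples.


def select_outputs(preds, output_names=None):
    """Return the list of predictions the metric should see."""
    if output_names is None:
        return list(preds.values())  # every Group output, in order
    return [preds[name] for name in output_names]


def accuracy(labels, probs):
    """Fraction of rows whose argmax matches the integer label."""
    correct = 0
    for label, row in zip(labels, probs):
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return correct / float(len(labels))


preds = {
    "studentsoftmax_output": [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]],
    "student_distillation_output": [[0.4], [0.9]],  # loss values, not class scores
}
labels = [0, 1]

chosen = select_outputs(preds, ["studentsoftmax_output"])
print(len(chosen))  # 1, which now matches the single label array
print(accuracy(labels, chosen[0]))  # 0.5
```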