I am trying to do image classification with an imbalanced data set, and I want to rescale each term of the cross entropy loss to correct for this imbalance. For example, if I have 2 classes with 100 images in class 0 and 200 images in class 1, then I would weight the loss terms for examples from class 0 by a factor of 2/3 and those for class 1 by a factor of 1/3. In other words, given the softmax output and the one-hot label of an example, which I denote by `(softmax_output, label)`, I want to compute the weighted cross entropy loss as:

`f(softmax_output, label) = -label[0]*log(softmax_output[0])*(2/3) - label[1]*log(softmax_output[1])*(1/3)`
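A minimal plain-Python sketch of this weighting, assuming the weights are inverse class frequencies normalized to sum to 1 (which reproduces the 2/3 and 1/3 above). `class_weights` and `weighted_ce` are hypothetical helper names for illustration, not MXNet functions.

```python
import math


def class_weights(counts):
    """Inverse-frequency weights, normalized so they sum to 1."""
    inv = [1.0 / c for c in counts]
    total = sum(inv)
    return [w / total for w in inv]


def weighted_ce(softmax_output, label, weights):
    """Weighted cross entropy of one example with a one-hot label.

    Terms whose label entry is 0 are skipped, so log(0) is never evaluated there.
    """
    return -sum(l * w * math.log(p)
                for p, l, w in zip(softmax_output, label, weights) if l)


weights = class_weights([100, 200])  # 100 images of class 0, 200 of class 1
print(weights)                       # approximately [2/3, 1/3]
print(weighted_ce([0.8, 0.2], [1, 0], weights))  # equals -(2/3) * log(0.8)
```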

For the sake of definiteness, suppose I want to use a model pretrained on ImageNet-1k. My idea so far is to strip off the last layer of the network, add a new fully connected layer followed by a softmax activation, and then wrap the weighted loss in `MakeLoss`. Unfortunately, I have some holes in my understanding of the API, and I have had trouble finding parts of the documentation that address this use case. Any help would be appreciated.
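One detail worth keeping in mind: `ImageRecordIter` (used by `data.get_rec_iter`) supplies each label as an integer class index by default, not as a one-hot vector. The per-example term is therefore usually gathered by index; in the symbol API, `mx.sym.pick(net, label, axis=1)` performs that gather. The plain-Python reference below, with hypothetical function names, is only a sketch for checking numbers by hand. It computes the batch mean of the weighted loss, which corresponds to the 1/batch-size gradient scaling that `MakeLoss(..., normalization='batch')` applies.

```python
import math


def pick(rows, indices):
    """Select rows[i][indices[i]] for each i (the gather mx.sym.pick does along axis 1)."""
    return [row[int(k)] for row, k in zip(rows, indices)]


def weighted_ce_batch(probs, labels, weights):
    """Mean over the batch of -weights[y] * log(p[y]) with integer labels y."""
    picked = pick(probs, labels)
    losses = [-weights[int(y)] * math.log(p) for p, y in zip(picked, labels)]
    return sum(losses) / len(losses)


probs = [[0.9, 0.1],   # example 0, true class 0
         [0.3, 0.7],   # example 1, true class 1
         [0.6, 0.4]]   # example 2, true class 1
labels = [0, 1, 1]
weights = [2.0 / 3, 1.0 / 3]
print(weighted_ce_batch(probs, labels, weights))
```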

#### Incorrect Code Below to Illustrate Intended Approach

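```python
import mxnet as mx

# `common` is the helper package from MXNet's example/image-classification directory
from common import data, fit, modelzoo

# download the pretrained ResNeXt-101 checkpoint and load it
(prefix, epoch) = modelzoo.download_model('imagenet1k-resnext-101-64x4d', '/path/to/model/location')
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

# cut the network at the flatten layer, dropping the original classifier
all_layers = sym.get_internals()
net = all_layers['flatten0_output']

# new 2-class classifier followed by a softmax activation
num_classes = 2
net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc')
net = mx.symbol.softmax(data=net, name='softmax_activation')
```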
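Taking `log` of a softmax output can fail when a class probability underflows to exactly 0. A common alternative is to compute log-probabilities directly with the log-sum-exp trick; MXNet provides `mx.sym.log_softmax` for this. The plain-Python illustration below uses hypothetical function names and only shows the numerical effect.

```python
import math


def softmax(logits):
    """Softmax with max-subtraction; tiny probabilities can still underflow to 0.0."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def log_softmax(logits):
    """log(softmax(z)) computed as z - logsumexp(z), which never takes log(0)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


logits = [0.0, 800.0]
print(softmax(logits))      # [0.0, 1.0]: math.log of the first entry would raise
print(log_softmax(logits))  # [-800.0, 0.0]: finite log-probabilities
```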

#### From here I’m not exactly sure how to proceed, but hopefully this can be corrected.

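```python
# 'softmax_label' is the label name data.get_rec_iter and mx.mod.Module use by default
label = mx.symbol.Variable("softmax_label")

# Not sure about this: indexing a Symbol selects one of its outputs, not a column,
# so net[0] and net[1] do not pick out the class probabilities.
ce = -label[0]*(2.0/3)*mx.sym.log(net[0]) - label[1]*(1.0/3)*mx.sym.log(net[1])
loss = mx.sym.MakeLoss(ce, normalization='batch', name='weighted_cross_entropy')

# keep all pretrained parameters except those of the old classifier
new_args = dict({k: arg_params[k] for k in arg_params if 'fc' not in k})

# `args` is the argparse namespace built as in the example's fine-tune.py
fit.fit(args=args, network=loss, data_loader=data.get_rec_iter,
        arg_params=new_args, aux_params=aux_params)
```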
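Whatever form the final symbol takes, it helps to know which gradient to expect. For a softmax followed by the weighted cross entropy, the gradient of one example's loss with respect to the logits is `w[y] * (p - onehot(y))`, i.e. the usual softmax cross-entropy gradient scaled by the class weight. The plain-Python finite-difference check below (hypothetical helper names, not MXNet code) is a sketch for confirming that claim numerically.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def loss(z, y, w):
    """Weighted cross entropy of one example: -w[y] * log(softmax(z)[y])."""
    return -w[y] * math.log(softmax(z)[y])


def analytic_grad(z, y, w):
    """d loss / d z_j = w[y] * (p_j - 1[j == y])."""
    p = softmax(z)
    return [w[y] * (p[j] - (1.0 if j == y else 0.0)) for j in range(len(z))]


def numeric_grad(z, y, w, eps=1e-6):
    """Central finite differences on each logit."""
    g = []
    for j in range(len(z)):
        zp = list(z)
        zp[j] += eps
        zm = list(z)
        zm[j] -= eps
        g.append((loss(zp, y, w) - loss(zm, y, w)) / (2 * eps))
    return g


z, y, w = [1.5, -0.5], 0, [2.0 / 3, 1.0 / 3]
print(analytic_grad(z, y, w))
print(numeric_grad(z, y, w))  # should agree to roughly 1e-6 or better
```

Since the class weight only rescales each example's gradient, the weighting can equally be viewed as a per-example loss scale rather than a change to the cross entropy itself.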