Hi there,
I am trying to define my own loss function that implements a pairwise contrastive loss. Given a mini-batch that contains several samples of the same class, I want to compute the pairwise distances between all feature vectors and then apply the contrastive loss to those distances.
I saw that, among the Gluon losses, only TripletLoss is available for this kind of metric learning [1].
My custom loss function is a subclass of mxnet.gluon.loss.Loss, defined as follows:
import mxnet
from mxnet.gluon.loss import Loss
from mxnet.gluon.loss import _apply_weighting
from mxnet.gluon.loss import _reshape_like
class MyLoss(Loss):
    """Pairwise contrastive loss over all sample pairs in a mini-batch.

    For every pair ``(i, j)`` in the batch, let ``d2_ij`` be the squared
    Euclidean distance between the two embeddings, ``d_ij = sqrt(d2_ij + eps)``,
    and ``y_ij = 1`` when the two samples share a class (else 0)::

        L_ij = 0.5 * (y_ij * d2_ij + (1 - y_ij) * max(0, margin - d_ij) ** 2)

    The returned loss has one value per sample ``i``: the sum of ``L_ij`` over
    the batch divided by the number of *other* samples (``B - 1``, at least 1).
    The ``i == j`` term is (up to float round-off) zero, so this is the mean
    over the other samples.

    Only broadcast / reduction operators on ``F`` are used (no numpy fancy
    indexing, ``.shape`` or ``.T``). The same code therefore runs with
    ``F = mx.nd`` and ``F = mx.sym`` (after ``hybridize()``), and it does not
    depend on the final mini-batch being exactly ``batch_size`` long.

    Parameters
    ----------
    batch_size : int
        Nominal mini-batch size. Stored for API compatibility; the loss
        computation does not use it.
    axis : int, default -1
        Stored for API compatibility; unused.
    sparse_label : bool, default True
        If True, ``label`` holds class ids with shape ``(B,)`` or ``(B, 1)``.
        If False, ``label`` is treated as a one-hot matrix of shape ``(B, C)``.
    from_logits : bool, default False
        Stored for API compatibility; unused.
    weight : float or None
        Global scalar weight, applied through ``_apply_weighting``.
    batch_axis : int, default 0
        Passed unchanged to the ``Loss`` base class.
    margin : float, default 1.0
        Dissimilar pairs closer than this distance are penalised.
    """

    def __init__(self, batch_size, axis=-1, sparse_label=True, from_logits=False,
                 weight=None, batch_axis=0, margin=1.0, **kwargs):
        super(MyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._batch_size = batch_size
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits
        self._margin = margin
        # Added under the square root so its gradient stays finite when two
        # embeddings coincide (including every sample paired with itself).
        self.eps = 1e-7
        # Flag kept from the original draft; no softmax variant is implemented.
        self.soft_max = False

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Return the per-sample contrastive loss, shape ``(B,)``.

        ``pred`` is assumed to be a 2-D ``(B, D)`` embedding matrix (required
        by ``F.dot``). ``label`` is assumed to have the same dtype as ``pred``
        (MXNet's default float32), because the two are multiplied together.
        TODO confirm both for the calling network.
        """
        if self.soft_max:
            raise NotImplementedError("soft_max variant is not implemented")

        sq_dist = self.pairwise_distance(F, pred)
        dist = F.sqrt(sq_dist + self.eps)
        same = self._same_class(F, label)

        # Similar pairs are pulled together; dissimilar pairs are pushed
        # apart until they are at least `margin` apart.
        pull = same * sq_dist
        push = (1 - same) * F.square(F.relu(self._margin - dist))
        pair_loss = 0.5 * (pull + push)

        # Number of other samples per row, computed from the tensor itself so
        # it also works for Symbols and for a short final mini-batch.
        n_others = F.sum(F.ones_like(pair_loss), axis=1) - 1
        loss = F.sum(pair_loss, axis=1) / F.maximum(n_others, 1.0)
        return _apply_weighting(F, loss, self._weight, sample_weight)

    def _same_class(self, F, label):
        """Return a ``(B, B)`` matrix with 1 where samples i and j share a class, else 0."""
        if self._sparse_label:
            label = F.reshape(label, shape=(-1,))
            return F.broadcast_equal(F.expand_dims(label, axis=1),
                                     F.expand_dims(label, axis=0))
        # One-hot rows: the dot product of two rows is 1 iff the classes match.
        return F.dot(label, label, transpose_b=True)

    def pairwise_distance(self, F, label):
        """Return the ``(B, B)`` matrix of squared Euclidean distances between rows.

        Despite its name, ``label`` is the matrix whose rows are compared (the
        embeddings). The parameter name is kept for backward compatibility.
        """
        sq_norm = F.sum(F.square(label), axis=1)
        gram = F.dot(label, label, transpose_b=True)
        # ||a||^2 + ||b||^2 - 2 a.b, built by broadcasting a column against a row.
        sq_dist = F.broadcast_add(F.expand_dims(sq_norm, axis=1),
                                  F.expand_dims(sq_norm, axis=0)) - 2 * gram
        # Float round-off can push the result slightly below 0; clamp it so a
        # later sqrt never sees a negative value.
        return F.relu(sq_dist)
MXNet version:
In [3]: mx.__version__
Out[3]: '0.12.1'
Is this the right way to do it?
Thanks,
[1] https://mxnet.incubator.apache.org/api/python/gluon/loss.html#mxnet.gluon.loss.TripletLoss