Implementing contrastive loss

Hi there,
I am trying to define my own loss function that implements a pairwise contrastive loss. Given a mini-batch that contains several samples of the same class, I want to compute the pairwise distances between their features and then apply the contrastive loss to those distances.
I noticed that the only loss of this kind that Gluon provides is TripletLoss [1].

My custom loss class derives from mxnet.gluon.loss.Loss:

import mxnet
from mxnet.gluon.loss import Loss
from mxnet.gluon.loss import _apply_weighting
from mxnet.gluon.loss import _reshape_like
    

class MyLoss(Loss):
    """Pairwise contrastive loss over a mini-batch.

    For every pair of samples (i, j) in the batch, with Euclidean feature
    distance d_ij and margin m, the pair term is::

        same class:       d_ij ** 2
        different class:  max(0, m - d_ij) ** 2

    The loss for sample i is the mean of its pair terms over all j in the
    batch (the j == i term is ~0 because d_ii ~ 0 and i has its own class).

    When ``self.soft_max`` is set to True, the block instead computes a
    softmax cross-entropy loss in the same way as
    ``mxnet.gluon.loss.SoftmaxCrossEntropyLoss``; ``axis``, ``sparse_label``
    and ``from_logits`` only affect that path.

    Parameters
    ----------
    batch_size : int
        Expected mini-batch size. Stored for reference only; the contrastive
        path works with whatever batch size it receives.
    axis : int, default -1
        Class axis for the softmax path.
    sparse_label : bool, default True
        Softmax path: labels are class indices (True) or a distribution (False).
    from_logits : bool, default False
        Softmax path: ``pred`` already holds log-probabilities.
    weight : float or None
        Global scalar weight applied to the loss.
    batch_axis : int, default 0
        Batch axis. The contrastive path assumes samples are rows (axis 0).
    margin : float, default 1.0
        Margin m below which different-class pairs are penalised.
    """

    def __init__(self, batch_size, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, margin=1.0, **kwargs):
        super(MyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._batch_size = batch_size
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits
        self._margin = margin
        # Added under the square root so the distance and its gradient stay
        # finite when two embeddings coincide (d/dx sqrt(x) is infinite at 0).
        self.eps = 1e-7
        # False -> contrastive loss (default); True -> softmax cross-entropy.
        self.soft_max = False

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Compute the per-sample loss.

        Parameters
        ----------
        F : module
            ``mxnet.ndarray`` or ``mxnet.symbol``, supplied by Gluon.
        pred : NDArray or Symbol
            Contrastive path: embeddings, one row per sample (2-D assumed).
            Softmax path: class scores / log-probabilities.
        label : NDArray or Symbol
            Class id per sample; any shape that flattens to (batch,).
        sample_weight : NDArray or Symbol, optional
            Per-sample weights broadcast-multiplied onto the loss; for the
            contrastive path it should be shaped (batch,).

        Returns
        -------
        NDArray or Symbol
            Loss of shape (batch,).
        """
        if self.soft_max:
            return self._softmax_cross_entropy(F, pred, label, sample_weight)

        dist_sq = self.pairwise_distance(F, pred)            # (n, n)
        dist = F.sqrt(dist_sq + self.eps)

        # same[i, j] = 1 where samples i and j share a class, else 0.
        # Built by broadcasting a column against a row, so it works for
        # Symbols too (no concrete shape or fancy indexing needed).
        flat_label = F.reshape(label, shape=(-1,))
        same = F.broadcast_equal(F.reshape(flat_label, shape=(-1, 1)),
                                 F.reshape(flat_label, shape=(1, -1)))

        # Pull same-class pairs together, push different-class pairs at
        # least `margin` apart.
        positive = same * dist_sq
        negative = (1 - same) * F.square(F.relu(self._margin - dist))

        loss = F.mean(positive + negative, axis=1)           # (n,)
        return _apply_weighting(F, loss, self._weight, sample_weight)

    def _softmax_cross_entropy(self, F, pred, label, sample_weight):
        """Softmax cross-entropy, mirroring gluon's SoftmaxCrossEntropyLoss."""
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = _reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

    def pairwise_distance(self, F, feat):
        """Return the (n, n) matrix of squared Euclidean distances between rows.

        Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. The norm vectors are
        combined by broadcasting rather than ``tile(..., reps=(shape[0], 1))``
        and ``.T``, so the method also works on Symbols, which have no
        concrete shape.

        Parameters
        ----------
        F : module
            ``mxnet.ndarray`` or ``mxnet.symbol``.
        feat : NDArray or Symbol
            Feature matrix, one row per sample. The ``axis=1`` reduction and
            the row-wise ``dot`` assume 2-D (batch, dim) input.
        """
        sq_norm = F.sum(F.square(feat), axis=1)                       # (n,)
        gram = F.dot(feat, feat, transpose_b=True)                     # (n, n)
        dist_sq = F.broadcast_add(F.expand_dims(sq_norm, axis=1),
                                  F.expand_dims(sq_norm, axis=0)) - 2 * gram
        # Floating-point cancellation can produce tiny negative values;
        # clamp at 0 so a later sqrt never yields NaN.
        return F.relu(dist_sq)

MXNet version:

In [3]: mx.__version__
Out[3]: '0.12.1'

Is this the right way to do it?

Thanks,

[1] https://mxnet.incubator.apache.org/api/python/gluon/loss.html#mxnet.gluon.loss.TripletLoss