PyTorch -> MXNet: much lower recall with the same hyperparameters

I am implementing a metric learning algorithm (the Multi-Similarity loss) from this paper: http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf

The authors released their implementation, which I am rewriting in MXNet: https://github.com/MalongTech/research-ms-loss
I ran it and verified that it does reproduce the numbers reported in the paper.

However, my MXNet code does not. In fact, even when I match the exact same hyperparameters (LR, WD, initialization, etc.), I get significantly worse results. So I am wondering: is there something PyTorch-specific that I forgot to match?

The key part of the code is probably the loss function. Here is the PyTorch one: https://github.com/MalongTech/research-ms-loss/blob/master/ret_benchmark/losses/multi_similarity_loss.py#L15

And this is my implementation:

from mxnet.gluon import loss


class MultisimilarityLoss(loss.Loss):
    def __init__(self, threshold=0.5, margin=0.1, positive_scale=2.0, negative_scale=40.0,
                 epsilon=1e-5, weight=None, batch_axis=0, **kwargs):
        super(MultisimilarityLoss, self).__init__(weight, batch_axis, **kwargs)
        self._threshold = threshold
        self._margin = margin
        self._scale_pos = positive_scale
        self._scale_neg = negative_scale
        self._epsilon = epsilon

    def hybrid_forward(self, F, embeddings, labels):
        # Embeddings are L2 normalized
        sim_mat = F.dot(embeddings, embeddings.transpose())  # BxB
        adjacency = F.broadcast_equal(labels.expand_dims(1), labels.expand_dims(0))
        neg_adjacency = 1 - adjacency

        pos_pairs = sim_mat * adjacency
        pos_pairs = pos_pairs * (pos_pairs < (1 - self._epsilon))  # zero out self-similarity (diagonal entries ~ 1)
        neg_pairs = sim_mat * neg_adjacency

        # Hardest negative similarity in each row
        max_negative = F.max(neg_pairs, axis=1, keepdims=True)
        # Minimum in each positive row; add a large value at the masked (zero) entries
        # so that zeroes are never selected as the minimum
        min_positive = F.min(
            F.broadcast_mul(F.max(pos_pairs, axis=1, keepdims=True) * 10, (pos_pairs == 0)) + pos_pairs,
            axis=1, keepdims=True)
        # Pair mining: keep negatives with similarity > (hardest positive - margin)
        # and positives with similarity < (hardest negative + margin)
        neg_pairs = F.broadcast_greater(neg_pairs + self._margin, min_positive) * neg_pairs
        pos_pairs = F.broadcast_lesser(pos_pairs - self._margin, max_negative) * pos_pairs

        # Soft weighting terms of the multi-similarity loss; the (pairs != 0) masks
        # drop the pairs that were filtered out above
        pos_loss = 1.0 / self._scale_pos * F.log(
            1 + F.sum(F.exp(-self._scale_pos * (pos_pairs - self._threshold)) * (pos_pairs != 0), axis=1))

        neg_loss = 1.0 / self._scale_neg * F.log(
            1 + F.sum(F.exp(self._scale_neg * (neg_pairs - self._threshold)) * (neg_pairs != 0), axis=1))

        loss = pos_loss + neg_loss
        loss = F.mean(loss)

        return loss
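
To show how I call it, here is a minimal toy check on random data (batch size, number of classes and embedding size are arbitrary; this is just to illustrate the expected inputs, i.e. L2-normalized embeddings and a 1-D vector of class ids):

import mxnet as mx

# Random L2-normalized embeddings (32 x 128) and random class ids from 8 classes
embeddings = mx.nd.L2Normalization(mx.nd.random.normal(shape=(32, 128)), mode='instance')
labels = mx.nd.random.randint(0, 8, shape=(32,)).astype('float32')

loss_fn = MultisimilarityLoss()
print(loss_fn(embeddings, labels))  # single scalar, the loss already averages over the batch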

Since the loss already takes the mean over the batch, I am using

trainer.step(1)
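
For context, the surrounding training step looks roughly like this (a sketch: the Dense layer stands in for the real backbone plus embedding head, the data is random, and the Adam settings are placeholders for the values I take from the reference config):

import mxnet as mx
from mxnet import autograd, gluon

ctx = mx.gpu(0) if mx.context.num_gpus() else mx.cpu()

# Stand-in for the real network (ResNet backbone + embedding head)
net = gluon.nn.Dense(512)
net.initialize(mx.init.Xavier(), ctx=ctx)

loss_fn = MultisimilarityLoss()
# Placeholder optimizer settings; the real ones mirror the reference config
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 1e-4, 'wd': 5e-4})

# Toy data loader stand-in: random features and class ids
dataset = gluon.data.ArrayDataset(
    mx.nd.random.normal(shape=(320, 1024)),
    mx.nd.random.randint(0, 8, shape=(320,)).astype('float32'))
train_loader = gluon.data.DataLoader(dataset, batch_size=80, shuffle=True)

for data, labels in train_loader:
    data, labels = data.as_in_context(ctx), labels.as_in_context(ctx)
    with autograd.record():
        embeddings = mx.nd.L2Normalization(net(data), mode='instance')
        batch_loss = loss_fn(embeddings, labels)
    batch_loss.backward()
    trainer.step(1)  # batch_size=1, since the loss already averages over the batch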

I also try to match the lower LR on the backbone and the frozen BN layers with this:

for v in net.base_net.collect_params().values():
    v.lr_mult = 0.1                                  # 10x lower LR on the backbone
    if 'batchnorm' in v.name or 'bn_' in v.name:
        v.grad_req = 'null'                          # no gradient updates for BN params
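
To double-check that the name filter catches what I think it does (Gluon parameter names depend on how the backbone was built), I print which parameters end up frozen and which only get the reduced LR:

# Sanity check of the name filter above (uses the same net.base_net as in the snippet)
frozen = [v.name for v in net.base_net.collect_params().values()
          if 'batchnorm' in v.name or 'bn_' in v.name]
scaled = [v.name for v in net.base_net.collect_params().values()
          if 'batchnorm' not in v.name and 'bn_' not in v.name]
print('BN params with grad_req=null:', len(frozen), frozen[:4])
print('params with lr_mult=0.1 only:', len(scaled), scaled[:4])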

Interestingly, I do get somewhat better results if I increase the LR 10-20x, but the final recall@1 is still much lower. Any idea why my code is not producing similar results?

Using MXNet 1.5.post0 from pip with CUDA 10.