Can I "reunion" the predictions after forward propagation when training a model with triplet loss?


#1

The code is as follows. The reunion part helps increase training speed.

for epoch in range(epochs):
    train_loss = 0.
    for features, indices in tqdm.tqdm(train_data):
        a_index, p_index, n_index = zip(*indices)
        data = features.as_in_context(ctx)
        with autograd.record():
            pred = net(data)
            # reunion the prediction as anchor, positive, negative
            a = pred[a_index, :]
            p = pred[p_index, :]
            n = pred[n_index, :]

            loss_ = loss(a, p, n).sum()
        loss_.backward()
        trainer.step(batch_size, ignore_stale_grad=True)
        train_loss += loss_.asscalar()

Is this reunion reasonable? And will loss_ update the parameters correctly?
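The idea in the snippet can be sketched without the MXNet training context, using plain NumPy. Here the random `embeddings` array is a hypothetical stand-in for the output of `net(data)` on a batch of unique samples, and the index triples play the role of `indices` from `train_data`:

```python
import numpy as np

# Hypothetical stand-in for pred = net(data): one embedding per unique sample.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(8, 4))  # 8 unique samples, 4-dim embeddings

# Each triplet is (anchor, positive, negative) indices into the unique batch.
indices = [(0, 1, 4), (0, 1, 5), (0, 2, 4)]
a_index, p_index, n_index = zip(*indices)

# "Reunion": gather rows of the single forward pass into triplet groups.
a = embeddings[list(a_index), :]
p = embeddings[list(p_index), :]
n = embeddings[list(n_index), :]

# Standard triplet loss with margin 1.0 on squared L2 distances.
margin = 1.0
d_ap = ((a - p) ** 2).sum(axis=1)
d_an = ((a - n) ** 2).sum(axis=1)
loss = np.maximum(0.0, d_ap - d_an + margin)
print(loss.shape)  # one loss value per triplet: (3,)
```

The same row of `embeddings` can appear in several triplets, which is exactly what makes the single forward pass cheaper than re-encoding every triplet.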


#2

Your code is doing an indexing operation on the pred NDArray, which supports a proper backward(). I’m not really sure what you mean by “reunion”, and I’m also not sure why indices is part of train_data, but your code looks fine.


#3

Let’s assume that one batch of data is

data = [[a1, p1, n1], [a1, p1, n2], [a1, p1, n3], [a1, p1, n4],
        [a1, p2, n1], [a1, p2, n2], [a1, p2, n3], [a1, p2, n4],
        [a1, p3, n1], [a1, p3, n2], [a1, p3, n3], [a1, p3, n4]]

What I plan to do is to deduplicate the samples, so the forward pass only runs once, on the unique samples [a1, p1, p2, p3, n1, n2, n3, n4]. But the triplet loss is still calculated for each triplet in the batch.
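The bookkeeping for this deduplication can be sketched as follows: map each unique sample name to its row in the deduplicated batch, then turn the 12 triplets above into index triples. The network then only has to embed 8 rows instead of 36:

```python
# Unique samples from the example batch, in forward-pass order.
unique = ["a1", "p1", "p2", "p3", "n1", "n2", "n3", "n4"]
pos = {name: i for i, name in enumerate(unique)}

# The 12 triplets from the batch, by name (same order as the data above).
triplets = [("a1", p, n) for p in ("p1", "p2", "p3")
                         for n in ("n1", "n2", "n3", "n4")]

# Index triples into the deduplicated batch: these are the (a_index,
# p_index, n_index) values used to regroup pred after the forward pass.
indices = [(pos[a], pos[p], pos[n]) for a, p, n in triplets]
print(len(unique), len(indices))  # 8 unique samples, 12 triplets
```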


#4

What you’re doing works fine, as shown by this code:

from mxnet import autograd, nd

a = nd.array(range(10)).reshape((5, 2))
a.attach_grad()
# Check simple case
with autograd.record():
    z0 = (a ** 2).sum()

z0.backward()
print(a.grad)

OUTPUT:

[[ 0.  2.]
 [ 4.  6.]
 [ 8. 10.]
 [12. 14.]
 [16. 18.]]
<NDArray 5x2 @cpu(0)>
# Check indexed case
idx0 = (1, 1, 1)
idx1 = (0, 2, 4)
with autograd.record():
    z1 = (a[idx0, :] ** 2).sum() + (a[idx1, :] ** 3).sum()

z1.backward()
print(a.grad)

OUTPUT:

[[  0.   3.]
 [ 12.  18.]
 [ 48.  75.]
 [  0.   0.]
 [192. 243.]]
<NDArray 5x2 @cpu(0)>
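For readers without MXNet at hand, the second gradient can be verified by hand in plain NumPy. Each occurrence of a row in idx0 contributes 2*a (derivative of the square) to that row’s gradient, and each occurrence in idx1 contributes 3*a**2 (derivative of the cube); repeated indices accumulate:

```python
import numpy as np

a = np.arange(10, dtype=float).reshape(5, 2)

grad = np.zeros_like(a)
for i in (1, 1, 1):    # idx0: row 1 selected three times
    grad[i] += 2 * a[i]
for i in (0, 2, 4):    # idx1: cubed terms
    grad[i] += 3 * a[i] ** 2

print(grad)
# [[  0.   3.]
#  [ 12.  18.]
#  [ 48.  75.]
#  [  0.   0.]
#  [192. 243.]]
```

This matches the a.grad output above, which is why indexing with repeated indices is safe inside autograd.record().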