Custom loss with sampling (Pose-AE associative embedding push-pull loss)

Hi,

I need to implement a simpler version of the loss in the following papers (using gluon):

The first one has a pytorch implementation here https://github.com/umich-vl/pose-ae-train.

The network is a segmentation network (U-net) and its output is an embedding vector for each pixel in the input image.
For training we have a “tag” label for each pixel in the output of the network. If two pixels have the same tag, we want the L2 distance between their embedding vectors to be close to zero; if their tags differ, we want that distance to be large.

The loss inputs are:

  • network predictions (one embedding vector per pixel):
    • feature map of size: [N] x [channels] x [height] x [width]
  • label (“tag”) mask:
    • int array of size: [N] x [height] x [width]

To calculate the loss in practice, we would like to sample a small set of pixels (x1, x2, x3, ...) from the network output feature map and apply the loss to pairs of pixels according to their tag values: if the tag values are the same we apply L2_loss(x1-x2), otherwise we apply Exp(-L2_loss(x1-x2)).

Can anyone point me to where to start?

I haven’t looked at the second paper, but for the first one, I think you can use gather_nd() to gather h_k(x_nk) for all x_nk values and then calculate the reference embedding (i.e. the mean embedding) for each object. I’ll take a look at the second paper tomorrow. Let me know if you need help using gather_nd (I struggle with it myself every time!).
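
For example, here is a minimal sketch of how gather_nd() can pull out the embedding vectors of one object and average them (the shapes and pixel coordinates below are made up just for illustration):

import mxnet as mx

# toy embedding map: 2 channels over a 3x3 image
emb = mx.nd.random.uniform(shape=(2, 3, 3))

# pixel coordinates of one tag, stacked as (num_indexed_dims, num_points)
ys = mx.nd.array([0, 0, 1])
xs = mx.nd.array([0, 1, 0])
idx = mx.nd.stack(ys, xs)                                            # shape (2, 3)

# move channels last so the two indexed axes come first, then gather
vecs = mx.nd.gather_nd(mx.nd.transpose(emb, axes=(1, 2, 0)), idx)    # shape (3, 2)

# reference (mean) embedding of the object
ref = mx.nd.mean(vecs, axis=0)                                       # shape (2,)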

I’ve implemented the “pull” loss, which, based on a tag mask (or indices list), gathers the embedding vectors and “pulls” them toward their mean. From here it should be easy to implement the other parts.

Now, how can I make it really fast? What is the best way to extend it to batches and multi-GPU? Do I need to write a CUDA kernel?

Minimal code for the “pull” loss:

import mxnet as mx
from mxnet import autograd
import numpy as np

def gather_embedding_vectors_of_tag(input_data, labels, tag_value, ctx, verbose=False):
    # input:
    #   input_data - floats ndarray, size: [channels] x [height] x [width]
    #   labels - int array with tag values, size: [height] x [width]
    #   tag_value - int, size: scalar
    # output:
    #   vectors correspond to tag value, size: [channels] x [num_of_corresponding_vectors]
    
    tag_indices_map = labels == tag_value
    indices_list = np.nonzero(tag_indices_map.asnumpy())
    indices_list = mx.nd.array(indices_list).as_in_context(ctx)

    #TODO: we can skip directly to this part by saving the indices list
    gathered_vectors = mx.nd.gather_nd(input_data, indices_list)
    return gathered_vectors.reshape((input_data.shape[0], -1))


def get_vectors_mean(vectors):
    mean = mx.nd.mean(vectors, axis=1, keepdims=True)
    mean = mx.nd.repeat(mean, repeats=vectors.shape[1], axis=1)
    return mean

np.random.seed(12345)
ctx = mx.gpu(0)
input_width = 3
input_height = 3
input_channels = 2
batch_size = 1
input_shape = [batch_size, input_channels, input_height, input_width]
input_data = np.random.randint(10, size=input_shape)
labels = np.zeros(shape=(batch_size, input_channels, input_height, input_width))
tag_value = 2
labels[0,:,0:2,0:2] = tag_value

input_data = mx.nd.array(input_data).as_in_context(ctx)
labels = mx.nd.array(labels).as_in_context(ctx)
print("input_data:")
print(input_data)
print("labels:")
print(labels)

l2_loss = mx.gluon.loss.L2Loss()
for i in range(1000):
    with autograd.record():
        input_data.attach_grad()        
        gathered_vectors = gather_embedding_vectors_of_tag(input_data[0], labels[0], tag_value, ctx)
        mean_value = get_vectors_mean(gathered_vectors)
        loss = l2_loss(gathered_vectors, mean_value)
        loss.backward(retain_graph=True)
    input_data = input_data - 0.5 * input_data.grad

print(input_data.grad)
print(input_data.asnumpy().round().astype(int))

You don’t want to create any numpy dependency in your computational graph. That means the indices_list should be computed for each sample as a pre-processing step and passed into the AE function, not calculated inside it. Also, your implementation doesn’t seem correct: you’re passing arrays of shape (2, 3, 3) for both data and label to gather_embedding_vectors_of_tag() and creating indices from any label that matches the tag, regardless of the channel. I think you should only create indices based on one channel (as stated in the paper).
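
For instance, a rough sketch of that split (the helper names and the single-channel label assumption here are mine):

import numpy as np
import mxnet as mx

def precompute_tag_indices(label_np, tag_value):
    # label_np: (H, W) int numpy array -- a single-channel tag mask per image
    ys, xs = np.nonzero(label_np == tag_value)
    return mx.nd.array(np.stack([ys, xs]))    # shape (2, K), ready for gather_nd

def pull_loss(embedding, idx):
    # embedding: (C, H, W) network output, idx: (2, K) precomputed indices
    vecs = mx.nd.gather_nd(mx.nd.transpose(embedding, axes=(1, 2, 0)), idx)   # (K, C)
    ref = mx.nd.mean(vecs, axis=0, keepdims=True)                             # (1, C)
    return ((vecs - ref) ** 2).mean()

Since precompute_tag_indices() runs in the data pipeline (e.g. a dataset transform), no numpy call ends up inside the recorded graph.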

The above implementation requires the data to be de-batchified and processed one image at a time; the result can then be re-batchified for the upper layers. This can be the most efficient approach when only a small number of pixels needs to be gathered (for example in body-joint detection).
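
Roughly, the per-sample loop could look like this (reusing the pull_loss() sketch above; N is the batch size):

def batch_pull_loss(embeddings, indices_per_sample):
    # embeddings: (N, C, H, W), indices_per_sample: list of (2, K_i) index arrays
    losses = [pull_loss(emb, idx) for emb, idx in zip(embeddings, indices_per_sample)]
    return mx.nd.concat(*losses, dim=0)    # re-batchified losses, shape (N,)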

An alternative implementation, which may work better when there are many pixels per tag (e.g. instance segmentation), could use scatter_nd() to create a mask for each tag. You’d then multiply the data by this mask, take the sum of the result and divide by the sum of the mask to get the reference embedding. The mask would be created from the label during the data pre-processing step (e.g. using a transform function passed into gluon.data.DataLoader()). This implementation can take better advantage of the GPU’s parallelism.
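
A rough sketch of that masked-mean idea (the mask is assumed to have been built from the label during pre-processing):

def masked_pull_loss(embedding, mask):
    # embedding: (C, H, W) network output
    # mask: (1, H, W) binary mask for a single tag, built in the data pipeline
    count = mask.sum()
    ref = mx.nd.sum(embedding * mask, axis=(1, 2), keepdims=True) / count    # (C, 1, 1)
    diff = (embedding - ref) * mask    # only the tagged pixels contribute
    return (diff ** 2).sum() / count

Everything here is a dense elementwise operation, so it maps onto the GPU without any gather/scatter inside the graph.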

A few other general rules to follow (a corrected sketch of the loop follows the list):

  • attach gradient to your params outside of the optimization loop
  • call backward() outside of the autograd scope
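
Applied to the snippet above, the loop would look roughly like this (everything else kept unchanged for brevity):

input_data.attach_grad()    # attach once, outside the loop
for i in range(1000):
    with autograd.record():
        gathered_vectors = gather_embedding_vectors_of_tag(input_data[0], labels[0], tag_value, ctx)
        mean_value = get_vectors_mean(gathered_vectors)
        loss = l2_loss(gathered_vectors, mean_value)
    loss.backward()    # backward outside the record scope
    input_data[:] = input_data - 0.5 * input_data.grad    # in-place update keeps the grad buffer attached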

Hi, if I understand correctly, you have a network prediction y, a ground truth y_hat, and two ‘tags’ or labels t1, t2 of shape [N] x [H] x [W]. The loss is calculated differently depending on whether t1 equals t2. I think MXNet’s built-in operators should be able to handle this:

loss = F.where((t1 == t2).expand_dims(1), (y - y_hat)**2, F.exp(-(y - y_hat)**2))
loss = loss.mean()
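
Note that (t1 == t2).expand_dims(1) adds a channel axis so the [N] x [H] x [W] tag comparison can line up with [N] x [C] x [H] x [W] predictions; depending on your MXNet version, where() may not broadcast that automatically, so you might need broadcast_like() (or to reduce over the channel axis first) to make the shapes match.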