Instance segmentation ground-truth representation (and gather_nd)


#1

Hi,

I have an instance segmentation network where the ground-truth segmentation masks are stored as grayscale images. Each tag value in the ground truth represents one instance (1,2,3,4,5,…,#instances), and 0 is background. I find the pixel indices of each instance and use them to gather the corresponding pixels from another map:

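```python
import cv2
import mxnet as mx
import numpy as np

tags = mx.nd.array([[0,1,1,1,0],
                    [0,1,1,0,0],
                    [0,1,0,2,2],
                    [0,0,0,2,2]])

tag_indices_list = list()
# asscalar() returns a NumPy float, so cast it before passing it to range().
for tag_value in range(int(tags.max().asscalar())):
    # Binary mask of the current instance (tags are 1-based), copied to NumPy.
    one_hot_tag_values = (tags == (tag_value + 1)).astype(np.uint8).asnumpy()
    # cv2.findNonZero returns (x, y) points with shape (N, 1, 2); reshape rather
    # than squeeze so a single-pixel instance still gives a 2-D array.
    indices = cv2.findNonZero(one_hot_tag_values).reshape(-1, 2)
    # Flip to (row, col) and transpose to the (2, N) layout gather_nd expects.
    tag_indices_list.append(mx.nd.array(np.fliplr(indices)).transpose())

other_map = mx.nd.array([[0,1,2,3,0],
                         [0,4,5,0,0],
                         [0,6,0,1,2],
                         [0,0,0,3,4]])

gathered_vectors = mx.nd.gather_nd(other_map, tag_indices_list[0])
print(gathered_vectors)
```

```
[ 1.  2.  3.  4.  5.  6.]
<NDArray 6 @cpu(0)>
```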
  • The “other_map” matrix could be an RGB image.
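For reference, `gather_nd` takes an index array whose first dimension indexes the leading axes of the data, so a 2-D map needs indices of shape (2, N): row 0 holds row indices and row 1 holds column indices. The NumPy sketch below mirrors that layout with fancy indexing (a NumPy equivalent, not MXNet itself) and shows that with an H×W×3 RGB map the same indices return one 3-channel vector per pixel. The `rgb` array is a stand-in built by repeating `other_map` across three channels.

```python
import numpy as np

# Tag map and value map from the example above.
tags = np.array([[0, 1, 1, 1, 0],
                 [0, 1, 1, 0, 0],
                 [0, 1, 0, 2, 2],
                 [0, 0, 0, 2, 2]])
other_map = np.array([[0, 1, 2, 3, 0],
                      [0, 4, 5, 0, 0],
                      [0, 6, 0, 1, 2],
                      [0, 0, 0, 3, 4]], dtype=np.float32)

# Index array in gather_nd layout: shape (2, N), row 0 = rows, row 1 = columns.
indices = np.stack(np.nonzero(tags == 1))

# One index array per leading axis mirrors gather_nd on a 2-D map.
gathered = other_map[indices[0], indices[1]]
print(gathered)  # [1. 2. 3. 4. 5. 6.]

# Stand-in RGB image of shape (H, W, 3): the same indices pick whole pixels.
rgb = np.repeat(other_map[:, :, None], 3, axis=2)
gathered_rgb = rgb[indices[0], indices[1]]
print(gathered_rgb.shape)  # (6, 3)
```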

The problem is that this line:

```python
one_hot_tag_values = (tags == (tag_value + 1)).astype(np.uint8).asnumpy()
```

takes too long and has become a bottleneck in my training.
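One likely contributor is that the line does a full-image comparison, a cast, and a blocking `asnumpy()` copy for every instance. A minimal sketch, assuming the tag map is copied to NumPy once before the loop (e.g. `tags_np = tags.asnumpy()`): `np.nonzero` then returns (row, col) arrays directly, so neither `cv2.findNonZero` nor the flip is needed. `instance_indices_per_tag` is a hypothetical helper name.

```python
import numpy as np

def instance_indices_per_tag(tags_np):
    """Return one (2, N) array of (row, col) indices per tag 1..max(tags).

    Hypothetical helper: assumes tags_np is a 2-D NumPy array obtained with a
    single tags.asnumpy() call before the loop, and that 0 is background.
    """
    num_instances = int(tags_np.max())
    indices_per_tag = []
    for tag_value in range(1, num_instances + 1):
        # np.nonzero already returns (rows, cols), so no flip is needed.
        rows, cols = np.nonzero(tags_np == tag_value)
        indices_per_tag.append(np.stack([rows, cols]))
    return indices_per_tag

tags_np = np.array([[0, 1, 1, 1, 0],
                    [0, 1, 1, 0, 0],
                    [0, 1, 0, 2, 2],
                    [0, 0, 0, 2, 2]])
tag_indices_list = instance_indices_per_tag(tags_np)
print(tag_indices_list[1].tolist())  # [[2, 2, 3, 3], [3, 4, 3, 4]]
```

Each entry could be wrapped with `mx.nd.array(...)` and passed to `mx.nd.gather_nd` as before. This still scans the whole image once per instance; it only removes the repeated device-to-host copies.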

I’ve tried caching the “one_hot” maps, but that takes a lot of memory when there are many instances. I’ve also tried “joblib” for multithreading, which was slower than a plain for loop.

Next, I was thinking of implementing a sparse-matrix representation, or of approximating the instances by polygons and decoding their indices quickly. Another option is to compute the indices in C++.
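The sparse-representation idea can be approximated without a dedicated sparse library: sort the flattened pixel indices by tag once and split the sorted order wherever the tag value changes, much like a CSR matrix stores row boundaries. This is a NumPy sketch, assuming the tag map is available as an integer NumPy array with 0 as background; `group_pixels_by_tag` is a hypothetical helper. It costs one O(HW log HW) sort instead of one full-image comparison per instance.

```python
import numpy as np

def group_pixels_by_tag(tags_np):
    """Group pixel coordinates by tag value with one sort (CSR-like grouping).

    Hypothetical helper: assumes tags_np is a 2-D integer NumPy array with 0
    as background. Returns {tag: (2, N) array of (row, col) indices}.
    """
    flat = tags_np.ravel()
    # A stable sort keeps each tag's pixels in row-major order.
    order = np.argsort(flat, kind="stable")
    sorted_tags = flat[order]
    # In the sorted array, each tag's first occurrence is where its group starts.
    unique_tags, starts = np.unique(sorted_tags, return_index=True)
    groups = np.split(order, starts[1:])
    result = {}
    for tag, flat_idx in zip(unique_tags, groups):
        if tag == 0:
            continue  # skip background pixels
        rows, cols = np.unravel_index(flat_idx, tags_np.shape)
        result[int(tag)] = np.stack([rows, cols])
    return result

tags_np = np.array([[0, 1, 1, 1, 0],
                    [0, 1, 1, 0, 0],
                    [0, 1, 0, 2, 2],
                    [0, 0, 0, 2, 2]])
groups = group_pixels_by_tag(tags_np)
print(sorted(groups))      # [1, 2]
print(groups[2].tolist())  # [[2, 2, 3, 3], [3, 4, 3, 4]]
```

Each value in the dictionary is already in the (2, N) layout used with `gather_nd` above.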

Any suggestions?


#2

@Oron_Anschel, could you explain what you are trying to achieve by doing this? That would be helpful; sometimes there is a better way once you consider the end goal. Try `pip install mxnet-mkl --pre`, which should give you the nightly build with MKL-DNN, where some NDArray operations are optimized for CPU.

For this specific task, you can try this and see if it is faster:

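```python
import mxnet as mx

tags = mx.nd.array([[0,1,1,1,0],
                    [0,1,1,0,0],
                    [0,1,0,2,2],
                    [0,0,0,2,2]])
other_map = mx.nd.array([[0,1,2,3,0],
                         [0,4,5,0,0],
                         [0,6,0,1,2],
                         [0,0,0,3,4]])

# Flatten both maps so one list of positions indexes each instance.
tags = tags.reshape((-1,))
other_map = other_map.reshape((-1,))
for tag_value in range(int(tags.max().asscalar())):
    # 0/1 mask of the current instance (tags are 1-based).
    one_hot = (tags == (tag_value + 1))
    # topk with k = number of ones returns the positions of the mask's ones.
    values = other_map[one_hot.topk(k=int(one_hot.sum().asscalar()))]
    print(values)
```

```
[ 1.  2.  3.  4.  5.  6.]
<NDArray 6 @cpu(0)>

[ 1.  2.  3.  4.]
<NDArray 4 @cpu(0)>
```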
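The flatten-and-topk trick converts a 0/1 mask into a list of positions. If the maps are available on the host as NumPy arrays, boolean-mask indexing does the same selection directly and returns the values in row-major order, so neither topk nor the per-tag sum used to size it is needed. A sketch assuming NumPy arrays rather than NDArrays:

```python
import numpy as np

tags = np.array([[0, 1, 1, 1, 0],
                 [0, 1, 1, 0, 0],
                 [0, 1, 0, 2, 2],
                 [0, 0, 0, 2, 2]])
other_map = np.array([[0, 1, 2, 3, 0],
                      [0, 4, 5, 0, 0],
                      [0, 6, 0, 1, 2],
                      [0, 0, 0, 3, 4]], dtype=np.float32)

per_instance = []
for tag_value in range(1, int(tags.max()) + 1):
    # Indexing with a 2-D boolean mask selects the masked pixels in row-major order.
    per_instance.append(other_map[tags == tag_value])

for values in per_instance:
    print(values)
```

With an RGB map of shape (H, W, 3), the same 2-D mask returns an (N, 3) array, one color vector per instance pixel.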