How to return a subset of a Gluon HybridBlock’s output?


Thank you for providing the MXNet/Gluon framework to the community.
I’d like to know how to return a subset of a HybridBlock’s output based on a constant vector of indices.

Here’s a minimal example of what I’m trying to do:

```python
import mxnet as mx
from mxnet.gluon import HybridBlock
from mxnet.gluon.nn import Conv2D, HybridSequential


class ConvNet(HybridBlock):

    def __init__(self, name: str, indices_vector: tuple):
        """
        :param name: Name of the network
        :param indices_vector: 1D list e.g. [0, 2, 3] which defines the values to select after the forward pass.
            Note, this must be coherent with the input size of the network (e.g. 8x8).
        """
        super(ConvNet, self).__init__(prefix=name + "_")
        self.body = HybridSequential(prefix="")

        with self.name_scope():
            self.body.add(Conv2D(channels=1, kernel_size=(1, 1), use_bias=False))

            self.indices_vector = mx.gluon.Constant('const', indices_vector)

    def hybrid_forward(self, F, x):
        """
        Compute forward pass
        :param F: MXNET-handle
        :param x: Input data to the block
        :return: Activation maps of the block
        """
        x = self.body(x)
        return F.take(x, self.indices_vector, axis=1)
```

However, this throws the following exception:

TypeError: hybrid_forward() got an unexpected keyword argument 'self.indices_vector'

I also tried to create an MXNet constant based on this issue:

but couldn’t get it to work that way.

I know that I can select the subset afterwards in NumPy, but I’d like to avoid that because of the additional memory and runtime overhead.


If you don’t need to hybridize your network, it should work if self.indices_vector is a plain NDArray.
Otherwise, one ugly workaround is to create both an NDArray and a Symbol holding your indices and pass take() whichever of the two matches the type of F.

I suspect there is a cleaner solution, but that’s all I can think of right away.


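Also, your error suggests that hybrid_forward is being called with an argument its signature doesn’t accept: Gluon passes every Parameter registered on the block, a Constant included, to hybrid_forward as an extra keyword argument, so the signature has to list it.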


Thank you for the reply, Adrian.
I managed to get it to work using the MXNet Symbol API and a Gluon SymbolBlock.