Gluon multi-input Block

Hi,

I have been trying to understand how to build a custom Block in Gluon to feed in multiple input data.

In Keras, for example, I found what I was looking for: https://keras.io/getting-started/functional-api-guide/#multi-input-and-multi-output-models, where different layers are concatenated.

How would I be able to have this functionality in Gluon?

What I currently have is a very simple custom block that stores two dense layers and concatenates their activations on the forward pass:

class ConcatLayer(gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        gluon.nn.HybridBlock.__init__(self)
        with self.name_scope():
            self.a = gluon.nn.Dense(3)
            self.b = gluon.nn.Dense(5)

    def hybrid_forward(self, F, first_input, *args, **kwargs):
        return F.concat(self.a(first_input), self.b(first_input), dim=1)

This is working, but it currently feeds the same input to both dense layers instead of taking two, or an arbitrary number of, inputs. If net is a HybridSequential model, how would I call it with two, or generally more, parameters?
With a model defined as:

    net = gluon.nn.HybridSequential()
    net.add(ConcatLayer())
    net.hybridize()
    net.collect_params().initialize(mx.init.Normal(sigma=1.), ctx=self.model_ctx)

just net(a, b)? I tried that and got an error in _get_graph, block.py:742 (TypeError: hybrid_forward() takes 3 positional arguments but 4 were given). I am using MXNet 1.3.

Any ideas of how to have a multiple input block in Gluon? Ideally, the block should be able to take other blocks as inputs in the constructor and make use of them in the forward pass.

Thanks,
Philipp

Hi @phschmid,

Easy :slight_smile:, you just have to declare an additional argument in hybrid_forward:

import mxnet as mx 
from mxnet import nd, gluon

class ConcatLayer(gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        gluon.nn.HybridBlock.__init__(self, **kwargs) # note: pass kwargs on to the HybridBlock initializer
        with self.name_scope():
            self.a = gluon.nn.Dense(3)
            self.b = gluon.nn.Dense(5)

    def hybrid_forward(self, F, input_1, input_2): # You don't really need *args, **kwargs in this case
        out1 = self.a(input_1)
        out2 = self.b(input_2)

        result = F.concat(out1,out2, dim=1)
        return result


net = ConcatLayer() # You don't need Sequential or HybridSequential for a single layer
net.hybridize(static_alloc=True,static_shape=True) # lightning gluon speed :) 
some_ctx = mx.cpu() # modified this from your code
net.initialize(mx.init.Normal(sigma=1.), ctx=some_ctx) # you no longer need collect_params for initializing; note the change in the ctx definition

batch = 32
input1 = nd.random.uniform(shape = [batch, 10])
input2 = nd.random.uniform(shape = [batch, 20])

out = net(input1,input2)
print (out.shape) 
# prints (32, 8)

Cheers,
Foivos

Hi @feevos,

thank you for your quick reply, indeed this is working like a charm. :slight_smile:

I was able to reproduce the error I mentioned in my post by changing your network definition from

net = ConcatLayer()
net.hybridize(static_alloc=True,static_shape=True)

to

net = gluon.nn.HybridSequential()
net.add(ConcatLayer())
net.hybridize()

I’ve declared the network like this because ultimately I will need multiple layers.

Do I need to change how I am calling the network with the inputs if I am using a Sequential model?

Thanks,
Philipp

Hi @phschmid,

indeed I can verify that if you try to encapsulate your network in a HybridSequential, it gives the error you mention. I recommend avoiding HybridSequential here and instead writing another HybridBlock that holds everything you want inside it. I hardly ever use HybridSequential, unless I want to stack standard layers. For the sake of example, say you want to stack batch norm and Conv2D layers after your ConcatLayer; I would write the code like this (assuming you've already defined ConcatLayer previously):

class CustomNet(HybridBlock):
    def __init__(self, depth, **kwargs):
        HybridBlock.__init__(self, **kwargs)

        with self.name_scope():
            # Declare here your initial ConcatLayer
            self.concat = ConcatLayer()

            # Stack here any more layers you want to have after (or before?) your ConcatLayer
            temp = gluon.nn.HybridSequential()  # A temporary variable

            # Here we add a bunch of layers, repeated depth times in total.
            for _ in range(depth):
                temp.add(gluon.nn.Conv2D(....., use_bias=False))  # Add some arguments to Conv2D
                temp.add(gluon.nn.BatchNorm())

            # Pass the temporary variable to a class member.
            self.main = temp

    def hybrid_forward(self, F, input1, input2):
        # Do your thing with the multiple inputs
        out = self.concat(input1, input2)
        # Then pass the (single) output to the standard HybridSequential stack of layers
        out = self.main(out)

        return out

I haven't tested the code, but you get the idea; it should be easy to make it work if there is a minor bug. Look also at these topics for more custom blocks:
a, b, c, d.

Please also note that HybridSequential and Sequential also behave as containers (you can access the layers you put inside them with list indexing, see this), which is very helpful if you want to write loops inside your hybrid_forward call.
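
For example, here is a minimal sketch of that indexing pattern (layer sizes are arbitrary, and I'm assuming a Gluon version where Sequential supports __getitem__ and __len__):

import mxnet as mx
from mxnet import nd, gluon

class IndexedStack(gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        gluon.nn.HybridBlock.__init__(self, **kwargs)
        with self.name_scope():
            # HybridSequential used purely as a container of layers
            self.stack = gluon.nn.HybridSequential()
            for units in [16, 8, 4]:
                self.stack.add(gluon.nn.Dense(units, activation='relu'))

    def hybrid_forward(self, F, x):
        # loop over the container with list indexing instead of calling it directly
        for i in range(len(self.stack)):
            x = self.stack[i](x)
        return x

net = IndexedStack()
net.initialize()
print(net(nd.random.uniform(shape=(2, 10))).shape)  # (2, 4)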

If you want, post the network you want to build, and people here can help.

Cheers,
Foivos

Thanks @feevos for the thorough sample and follow up links. I think this does the trick for me.

Best,
Philipp

I know this post is kind of old, but FYI you can actually use a HybridSequential. You just need an object that accepts multiple arguments, applies them to the first layer, and then continues normally through the rest of the layers.

The custom sequential network object:

from mxnet.gluon.nn import HybridSequential

class SequentialMultiInput(HybridSequential):

    def hybrid_forward(self, F, *args):
        x = list(args)
        first = True
        for block in self._children.values():
            if first:
                x = block(*x)
                first = False
            else:
                x = block(x)
        return x

The concatenation block:

from mxnet.gluon import HybridBlock

class ConcatLayer(HybridBlock):

    def __init__(self, *args, **kwargs):
        super(ConcatLayer, self).__init__(**kwargs)
        self.layers = []
        for layer in args:
            self.register_child(layer)
            self.layers.append(layer)

    def hybrid_forward(self, F, *args):
        inputs = list(args)
        outputs = []
        for (layer, input) in zip(self.layers, inputs):
            outputs.append(layer(input))
        return F.concat(*outputs)

And finally, your network:

from mxnet.gluon.nn import Dense

net = SequentialMultiInput("merge_")
with net.name_scope():
    net.add(ConcatLayer(Dense(3), Dense(4), Dense(5)))
    net.add(Dense(2))
net.hybridize()

For the training data, I use a random data generator as an example:

from mxnet.gluon.data import DataLoader, Dataset
from mxnet import nd
import numpy as np
import random

class LogsDataset(Dataset):
    def __init__(self):
        self.len = int(1024)

    def __getitem__(self, idx):
        feature01 = nd.array(np.random.rand(1, 16, 16))
        feature02 = nd.array(np.random.rand(100, 8))
        feature03 = nd.array(np.random.rand(16))

        label = nd.full(1, random.randint(0, 1), dtype="float32")

        return feature01, feature02, feature03, label

    def __len__(self):
        return self.len

train_data = DataLoader(LogsDataset(), batch_size=64)
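
If you want to see the batch shapes the network will receive, you can peek at one batch (purely illustrative):

# inspect one batch to confirm how the three features and the label are collated
for X01, X02, X03, y in train_data:
    print(X01.shape, X02.shape, X03.shape, y.shape)
    # (64, 1, 16, 16) (64, 100, 8) (64, 16) (64, 1)
    break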

You train and export the network like this:

from mxnet import autograd
from mxnet.gluon import Trainer
from mxnet.gluon.loss import SoftmaxCrossEntropyLoss
import mxnet as mx

ctx = mx.cpu()
net.initialize(ctx=ctx)
softmax_loss = SoftmaxCrossEntropyLoss()
trainer = Trainer(net.collect_params(), optimizer="adam")

for epoch in range(5):
    for idx, (X01, X02, X03, y) in enumerate(train_data):
        with autograd.record():
            output = net(X01, X02, X03)
            loss = softmax_loss(output, y)

        loss.backward()
        trainer.step(64)

net.export("net")
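
If you later want to load the exported model back into Gluon, SymbolBlock.imports needs the names of all inputs. For a hybridized multi-input network these are typically data0, data1 and data2, but that is an assumption here; check the exported net-symbol.json to be sure:

from mxnet import gluon
import mxnet as mx

# load the exported symbol and parameters back as a Gluon block
# the input names below are assumed; verify them in net-symbol.json
loaded = gluon.nn.SymbolBlock.imports(
    "net-symbol.json",
    ["data0", "data1", "data2"],
    param_file="net-0000.params",
    ctx=mx.cpu(),
)
out = loaded(X01, X02, X03)  # same three inputs as during training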

Hope this helps.
