mx.symbol.Group() combines several symbols into a single symbol with multiple outputs, which lets you build a model with multiple heads.
One example is an AlphaZero-like model with a value head and a policy head:
You can also find a visualization of the corresponding NN architecture.
You can do the same with the Gluon API like this:
def hybrid_forward(self, F, x):
    out = self.body(x)
    value = self.value_head(out)
    policy = self.policy_head(out)
    return [value, policy]
Then define the total loss as a custom linear combination of the individual head losses:
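As a minimal, framework-independent sketch of such a combination, assuming a squared-error value loss and a cross-entropy policy loss. The function names and the weights value_weight and policy_weight are hypothetical and not part of the MXNet API; in Gluon the same weighted sum would be applied to the outputs of the loss blocks you choose:

```python
import math


def value_loss(predicted, target):
    """Squared error between the predicted and the target game outcome."""
    return (predicted - target) ** 2


def policy_loss(predicted_probs, target_probs, eps=1e-12):
    """Cross-entropy between the target and the predicted move distribution."""
    return -sum(t * math.log(p + eps)
                for p, t in zip(predicted_probs, target_probs))


def combined_loss(value_pred, value_target, policy_pred, policy_target,
                  value_weight=0.5, policy_weight=0.5):
    """Weighted sum of both head losses (the weights are assumptions)."""
    return (value_weight * value_loss(value_pred, value_target)
            + policy_weight * policy_loss(policy_pred, policy_target))


# Example: weight the policy loss much more heavily than the value loss.
loss = combined_loss(0.5, 1.0, [0.25, 0.75], [0.0, 1.0],
                     value_weight=0.01, policy_weight=0.99)
```

The weights control how strongly each head drives the gradient of the shared body, so they are usually treated as hyperparameters.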