To add some graph convolution operations to GRUCell, I need to rewrite MXNet's GRUCell.
Specifically, I need to add an extra input parameter, the adjacency matrix (`adj`). Unfortunately, I get the error `TypeError: forward() takes 3 positional arguments but 4 were given`.
```python
class mgcn_grucell(HybridRecurrentCell):
    def __init__(self, hidden_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(mgcn_grucell, self).__init__(prefix=prefix, params=params)

    # mgcn_grucell differs from GRUCell only in that the rewritten
    # hybrid_forward() takes an extra parameter, adj.
    def hybrid_forward(self, F, inputs, states, adj):
        ...  # GRU computation with graph convolution (elided)


class HybridRecurrentCell(RecurrentCell, HybridBlock):
    """HybridRecurrentCell supports hybridize."""
    def __init__(self, prefix=None, params=None):
        super(HybridRecurrentCell, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x, *args, **kwargs):
        raise NotImplementedError


class RecurrentCell(Block):
    # .......

    # forward() with adj added
    def forward(self, inputs, states, adj):
        self._counter += 1
        return super(RecurrentCell, self).forward(inputs, states, adj)
```
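For reference, here is a minimal pure-Python sketch of the call path that produces this kind of error. The classes `Block`, `RecurrentCell`, `CellWithoutForward` and `CellWithForward` are hypothetical mocks, not MXNet's real code; the only behavior carried over from Gluon is that calling a block passes its positional arguments to `forward()`. Adding `adj` to `hybrid_forward` alone does not help, because the inherited `forward(self, inputs, states)` is what rejects the fourth argument.

```python
# Hypothetical stand-ins for Gluon's Block/RecurrentCell; only the call path
# matters here, not the real MXNet internals.
class Block:
    def __call__(self, *args):
        # Like Gluon's Block.__call__, pass positional arguments to forward().
        return self.forward(*args)


class RecurrentCell(Block):
    def forward(self, inputs, states):
        # Fixed arity: self, inputs, states (3 positional arguments).
        return ("base", inputs, states)


class CellWithoutForward(RecurrentCell):
    # Only hybrid_forward gains `adj`; forward() is still inherited.
    def hybrid_forward(self, F, inputs, states, adj):
        return ("hybrid", inputs, states, adj)


class CellWithForward(CellWithoutForward):
    # Overriding forward() in the cell itself accepts the extra argument.
    def forward(self, inputs, states, adj):
        # F is passed as None because this mock has no ndarray/symbol module.
        return self.hybrid_forward(None, inputs, states, adj)


def call(cell, *args):
    """Return the cell's output, or the TypeError message if the call fails."""
    try:
        return cell(*args)
    except TypeError as err:
        return str(err)


failing = call(CellWithoutForward(), "x", "h", "adj")
working = call(CellWithForward(), "x", "h", "adj")
print(failing)  # message ends with "forward() takes 3 positional arguments but 4 were given"
print(working)  # ('hybrid', 'x', 'h', 'adj')
```

In the real cell, an overriding `forward()` would presumably need to delegate to `HybridBlock.forward(self, inputs, states, adj)` rather than call `hybrid_forward` directly, so that the registered parameters are still passed in; I have not verified this against a specific MXNet version.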
Since HybridRecurrentCell inherits from RecurrentCell, I also added `adj` to `forward()` in RecurrentCell, but it still doesn't work.
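One thing that may be worth checking (an assumption, since the snippet is abbreviated): in the order pasted above, `HybridRecurrentCell` is created before the modified `RecurrentCell`. In Python, a class's bases are bound when its `class` statement runs, so defining a new `RecurrentCell` afterwards only rebinds the name and leaves `HybridRecurrentCell` pointing at the original base. The toy classes below are hypothetical and only demonstrate that behavior.

```python
class RecurrentCell:  # stands in for the library's original class
    def forward(self, inputs, states):
        return "original"


class HybridRecurrentCell(RecurrentCell):  # base is bound right here
    pass


class RecurrentCell:  # a later redefinition only rebinds the name
    def forward(self, inputs, states, adj):
        return "patched"


original_base = HybridRecurrentCell.__mro__[1]
result = HybridRecurrentCell().forward("x", "h")
print(original_base is RecurrentCell)  # False
print(result)  # original
```

Printing `mgcn_grucell.__mro__` in the real script should show which `RecurrentCell` actually sits in the inheritance chain.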
How can I fix this?