MXNet Forum

Rewrite the GRUCell, Error:TypeError: forward() takes 3 positional arguments but 4 were given


#1

In order to add some graph convolution operations to GRUCell, I have to rewrite MXNet's GRUCell.
Technically, I need to add one more input, i.e., the adjacency matrix (adj). Unfortunately, I get the error “TypeError: forward() takes 3 positional arguments but 4 were given”.

    class mgcn_grucell(HybridRecurrentCell):
        def __init__(self, hidden_size,
                     i2h_weight_initializer=None, h2h_weight_initializer=None,
                     i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                     input_size=0, prefix=None, params=None):
            super(mgcn_grucell, self).__init__(prefix=prefix, params=params)
            # mgcn_grucell differs from GRUCell only in that the rewritten
            # hybrid_forward() takes an added parameter, adj.

        def hybrid_forward(self, F, inputs, states, adj):
            # ......

    class HybridRecurrentCell(RecurrentCell, HybridBlock):
        """HybridRecurrentCell supports hybridize."""
        def __init__(self, prefix=None, params=None):
            super(HybridRecurrentCell, self).__init__(prefix=prefix, params=params)

        def hybrid_forward(self, F, x, *args, **kwargs):
            raise NotImplementedError

    class RecurrentCell(Block):
        # .......
        # forward() with adj added
        def forward(self, inputs, states, adj):
            self._counter += 1
            return super(RecurrentCell, self).forward(inputs, states, adj)

HybridRecurrentCell inherits from RecurrentCell, so I also added adj to forward() in RecurrentCell, but it still doesn’t work.

So, how can I fix this?


#2

This is the hybrid_forward() of GRUCell:

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):

Your signature is missing all the weights. Am I missing something?
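
For what it's worth, the error message itself can be reproduced without MXNet. Below is a minimal plain-Python sketch of the call chain. The classes are simplified stand-ins that only borrow the Gluon names, and `BrokenCell`, `FixedCell`, and the parameter values are hypothetical. It assumes, as in the snippets quoted in this thread, that calling the cell passes every positional argument to `forward()`, that `RecurrentCell.forward()` accepts only `inputs` and `states`, and that the weights reach `hybrid_forward()` as keyword arguments. Under those assumptions, one possible fix is to override `forward()` in the subclass itself and to list the weights after `adj` in `hybrid_forward()`.

```python
# Plain-Python stand-ins for the Gluon classes; nothing here imports MXNet.
class Block:
    def __init__(self):
        pass

    def __call__(self, *args):
        # Calling the block forwards every positional argument to forward().
        return self.forward(*args)


class HybridBlock(Block):
    def __init__(self):
        super().__init__()
        # Stand-in for registered parameters (hypothetical scalar values).
        self.params = {"i2h_weight": 2.0, "h2h_weight": 3.0}

    def forward(self, x, *args):
        # Parameters are handed to hybrid_forward() as keyword arguments.
        return self.hybrid_forward(None, x, *args, **self.params)


class RecurrentCell(Block):
    def __init__(self):
        super().__init__()
        self._counter = 0

    def forward(self, inputs, states):
        # Fixed two-input signature: a third input fails right here.
        self._counter += 1
        return super().forward(inputs, states)


class HybridRecurrentCell(RecurrentCell, HybridBlock):
    pass


class BrokenCell(HybridRecurrentCell):
    # adj comes first, then the weights that arrive as keyword arguments.
    def hybrid_forward(self, F, inputs, states, adj, i2h_weight, h2h_weight):
        return adj * (i2h_weight * inputs + h2h_weight * states)


class FixedCell(BrokenCell):
    def forward(self, inputs, states, adj):
        # Override forward() in the subclass so the extra input is accepted,
        # then skip RecurrentCell.forward() and hand off to HybridBlock.
        self._counter += 1
        return HybridBlock.forward(self, inputs, states, adj)


error_message = None
try:
    BrokenCell()(1.0, 1.0, 0.5)
except TypeError as err:
    error_message = str(err)

result = FixedCell()(1.0, 1.0, 0.5)
print(error_message)
print(result)
```

In the real cell, the `hybrid_forward()` signature would presumably become `(self, F, inputs, states, adj, i2h_weight, h2h_weight, i2h_bias, h2h_bias)`, but I have not run that against MXNet.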