Gluon gets stuck when applying custom function on model parameters

I have a piece of code that inexplicably gets stuck, and I cannot understand why. The code defines a new autograd.Function with a custom backward pass and applies it to the model parameters. If I apply this function to the model inputs/outputs there is no problem, but when I apply it to the parameters, the code gets stuck in the Trainer's step call (opt.step). Any hints on what's wrong and how to fix it are appreciated. (Yes, I know I can get this exact functionality in other ways. This is just a test function to keep the code simple. In practice I will have a backward step that I need to define myself.)

import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon.nn as nn

# Custom function with saved tensor being the input (function doesn't matter)
class myfun( mx.autograd.Function):
    def forward( self, x):
        self.save_for_backward( x)
        return x*x

    def backward( self, dx):
        x, = self.saved_tensors
        return dx * 2*x

# A layer using that function on its parameter
class mylinear( nn.Block):
    def __init__( self, ins, outs):
        super( mylinear, self).__init__()
        self.weight = self.params.get( 'weight', shape=(outs,ins))

    def forward( self, x):
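        # Apply the custom function to the parameter, then use it as a linear map
        w = myfun()( self.weight.data())
        return nd.dot( x, w.transpose())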

# Training loop
def trainit( model, bt=1024):
    opt = mx.gluon.Trainer( model.collect_params(), 'SGD', {'learning_rate':1e-3})
    for ep in range( 10):
        # Make up some data
        x = nd.random.normal( shape=(bt,10))
        y = nd.random.normal( shape=(bt,2))

        # Get loss and update
        with mx.autograd.record( train_mode=True):
            l = (model( x) - y).abs().mean()
        l.backward()
        opt.step( bt) # <- stuck here after the first pass
        print( 'Epoch:', ep, 'loss:', l.asscalar())

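model = mylinear( 10, 2)
model.initialize()  # Gluon parameters must be initialized before the first forward pass
trainit( model)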
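One way to rule out the hand-written gradient itself is to check the backward rule of myfun on its own, outside MXNet. The sketch below is a NumPy stand-in and not MXNet code. MyFunSketch and finite_diff_grad are hypothetical names, and self.saved only imitates save_for_backward/saved_tensors. It compares the custom rule (dy * 2 * x) with central finite differences. This validates only the gradient formula. It says nothing about how MXNet schedules the backward pass when the input is a parameter.

```python
import numpy as np

class MyFunSketch:
    """Hypothetical NumPy stand-in for myfun: forward squares its input and
    stores it; backward applies the hand-written gradient rule."""

    def forward(self, x):
        self.saved = x            # analogue of save_for_backward(x)
        return x * x

    def backward(self, dy):
        x = self.saved            # analogue of self.saved_tensors
        return dy * 2 * x         # custom rule for d(x*x)/dx

def finite_diff_grad(f, x, dy, eps=1e-6):
    """Central differences of sum(f(x) * dy) with respect to each entry of x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (np.sum(f(xp) * dy) - np.sum(f(xm) * dy)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))
dy = rng.normal(size=(3, 4))      # upstream gradient arriving at the function

fn = MyFunSketch()
fn.forward(x)
analytic = fn.backward(dy)
# A fresh instance is used inside the lambda so fn.saved is not overwritten
numeric = finite_diff_grad(lambda v: MyFunSketch().forward(v), x, dy)
max_err = float(np.max(np.abs(analytic - numeric)))
```

If max_err is tiny, the formula is consistent, and the problem lies elsewhere.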

There are multiple issues in your code. First, there are no gradients attached to x and y. You need to add the following:

    x = nd.random.normal( shape=(bt,10))
    y = nd.random.normal( shape=(bt,2))
    x.attach_grad()
    y.attach_grad()

Also replace the class myfun with:

def f(x):
    return x*x

Thanks, but even with the gradients attached it still doesn't work. I cannot change this function. As I mentioned above, it is not the function I actually use; it is just for illustration. I need to implement a custom backward function for something applied to the parameters, and a plain function like the one you suggested doesn't let me define a custom backward step.
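Writing the chain rule out by hand shows what the optimizer should receive when the custom function sits on a parameter. This NumPy sketch is not MXNet code. square_forward, square_backward and loss_and_weight_grad are hypothetical names. It mirrors mylinear: the raw weight W goes through the same forward/backward rule as myfun and is then used as a linear map. As an assumption for illustration, it uses squared error instead of the question's absolute error so the finite-difference check is smooth. It then takes one plain SGD step on W.

```python
import numpy as np

def square_forward(w):
    return w * w                      # same forward as myfun

def square_backward(w, grad_out):
    return grad_out * 2 * w           # same hand-written rule as myfun.backward

def loss_and_weight_grad(W, x, y):
    """Squared-error loss of x @ (W*W).T and its gradient w.r.t. the raw weight W."""
    Wt = square_forward(W)            # custom function applied to the parameter
    diff = x @ Wt.T - y               # like nd.dot(x, w.transpose()) - y
    loss = np.mean(diff ** 2)
    d_out = 2 * diff / diff.size      # d loss / d output
    d_Wt = d_out.T @ x                # d loss / d transformed weight, shape (outs, ins)
    d_W = square_backward(W, d_Wt)    # chain through the custom backward
    return loss, d_W

rng = np.random.default_rng(0)
ins, outs, bt, lr = 10, 2, 64, 1e-3
W = rng.normal(size=(outs, ins))
x = rng.normal(size=(bt, ins))
y = rng.normal(size=(bt, outs))

loss0, grad_W = loss_and_weight_grad(W, x, y)

# Finite-difference check of the chained gradient, one weight entry at a time
eps = 1e-6
numeric = np.zeros_like(W)
for idx in np.ndindex(W.shape):
    Wp, Wm = W.copy(), W.copy()
    Wp[idx] += eps
    Wm[idx] -= eps
    numeric[idx] = (loss_and_weight_grad(Wp, x, y)[0]
                    - loss_and_weight_grad(Wm, x, y)[0]) / (2 * eps)

# One plain SGD step on the raw weight
W_new = W - lr * grad_W
loss1, _ = loss_and_weight_grad(W_new, x, y)
```

In Gluon, the gradient stored for the parameter after calling l.backward() can be read with model.weight.grad(), so it could be compared against a reference like grad_W. MXNet executes operations asynchronously, so the line that appears stuck is not necessarily where the problem starts. Calling model.weight.grad().asnumpy() right after l.backward(), before opt.step, may show whether the backward pass itself completes. Note also that Trainer.step(bt) normalizes the gradient by 1/bt, which this plain SGD step does not do.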