I have a piece of code that inexplicably gets stuck, and I cannot understand why. What happens in this code is that I define a new autograd.Function with a custom backward pass and apply it to the model parameters. If I apply this function to the model inputs/outputs there is no problem, but when it is applied to the parameters the code gets stuck in the optimizer.step function. Any hints on what is wrong and how to fix it are appreciated. (Yes, I know I can get this exact functionality another way — this is just a test function to keep the code simple; in practice I will have a backward step that needs to be defined by me.)
import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon.nn as nn
# Custom function with saved tensor being the input (function doesn't matter)
class myfun(mx.autograd.Function):
    """Test autograd.Function: forward computes x * x, and the custom
    backward implements its analytic gradient d(x*x)/dx = 2*x.

    The saved tensor is the forward input (the function itself is only
    a stand-in for a user-defined backward pass).
    """

    def forward(self, x):
        # Stash the input so backward can recover it.
        self.save_for_backward(x)
        return x * x

    def backward(self, dy):
        saved, = self.saved_tensors
        # Chain rule: upstream gradient times the local derivative 2*x.
        return 2 * saved * dy
# A layer using that function on its parameter
class mylinear(nn.Block):
    """Bias-free linear layer whose weight is passed through `myfun`
    before the matrix product — i.e. the custom Function is applied to
    a *parameter* rather than to an activation.
    """

    def __init__(self, ins, outs):
        super(mylinear, self).__init__()
        # Parameter of shape (outs, ins); values are set later by
        # model.initialize().
        self.weight = self.params.get('weight', shape=(outs, ins))

    def forward(self, x):
        # Route the raw weight through the custom Function, then x @ w.T.
        transformed = myfun()(self.weight.data())
        return nd.dot(x, transformed.transpose())
# Training loop
def trainit(model, bt=1024):
    """Run 10 SGD passes of `model` on freshly sampled random data.

    Parameters
    ----------
    model : gluon Block mapping (bt, 10) inputs to (bt, 2) outputs.
    bt : batch size; also passed to Trainer.step as the rescale factor.
    """
    model.initialize()
    opt = mx.gluon.Trainer(model.collect_params(), 'SGD',
                           {'learning_rate': 1e-3})
    for ep in range(10):
        # Synthetic regression batch: 10 features -> 2 targets.
        x = nd.random.normal(shape=(bt, 10))
        y = nd.random.normal(shape=(bt, 2))
        # Record the forward pass so autograd can build the graph.
        with mx.autograd.record(train_mode=True):
            l = (model(x) - y).abs().mean()  # mean absolute error
        l.backward()
        # NOTE(review): reported to hang here after the first pass when the
        # custom Function wraps a parameter rather than an activation —
        # presumably an engine-level dependency issue; confirm against
        # MXNet's autograd.Function docs/known issues.
        opt.step(bt)
        print('Epoch:', ep, 'loss:', l.asscalar())
# Kick off training on a 10-input / 2-output test layer.
net = mylinear(10, 2)
trainit(net)