Find out whether an NDArray can be a head in autograd.backward

Hello, I have a problem where I want to use

autograd.backward(heads, head_grads)

with a list of heads. For that to work, all entries in heads must be part of a graph built up in autograd.record(), or must have a gradient attached directly.

Is there a way to find out whether this situation holds for an NDArray object? I can track this myself, sure, but my code would be so much cleaner.
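
Just to sketch what I mean by tracking this myself (the helper and the bookkeeping set are made up for illustration, not an existing API):

import mxnet as mx
from mxnet import autograd

valid_heads = set()  # ids of arrays that I know can serve as heads

def attach(x):
    # attach a gradient and remember the array
    x.attach_grad()
    valid_heads.add(id(x))
    return x

a = attach(mx.nd.zeros((1,)))
with autograd.record():
    b = 2 * a              # computed from a recorded array with a gradient attached
    valid_heads.add(id(b))

if id(b) in valid_heads:
    b.backward()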

To give context:

import mxnet as mx
from mxnet import autograd

a = mx.nd.zeros((1,))
a.attach_grad()
with autograd.record():
    b = 2 * a
b.backward()

This works fine. But:

a = mx.nd.zeros((1,))
with autograd.record():
    b = 2 * a
b.backward()

This fails and tells me that b cannot be differentiated. I totally understand that, but: is there a way to find out whether b is amenable to backward, other than just trying it out? For me, trying it out is not really a sane option, because I want to use autograd.backward with a whole list of heads and head_grads.
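
To make the use case concrete, here is a minimal sketch of the kind of multi-head call I have in mind (the arrays and gradients are just placeholders):

import mxnet as mx
from mxnet import autograd

x = mx.nd.ones((2,))
x.attach_grad()
with autograd.record():
    heads = [2 * x, x * x]          # two heads built in the same recorded graph
head_grads = [mx.nd.ones((2,)), mx.nd.ones((2,))]

# One backward call over all heads at once
autograd.backward(heads, head_grads)
print(x.grad)  # d(2x)/dx + d(x*x)/dx = 2 + 2*x = [4. 4.]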

First, I thought that
b.grad is not None
does the job. But this does not work. For example:

a = mx.nd.zeros((1,))
a.attach_grad()
with autograd.record():
    b = 2 * a
a.grad is None
==> False
b.grad is None
==> True

So this test just tells me whether I myself have attached a gradient. I suppose this is due to lazy creation of the grad variable: it only exists after I have called b.backward(), but that is too late.

Thanks a lot for any enlightenment. I just could not find this in the NDArray API docs.

@mseeger I am not sure there's a way to find this out without trying. One way is to simply put a try / except block around your autograd.backward(heads, gradient_heads) call. If you want to actually filter your list so that it only contains valid heads, I would suggest using a function like the following:

def can_backward(head, out_grad=None):
    # Attempt a throwaway backward pass; retain_graph=True keeps the graph
    # around so a real backward call can still follow later.
    try:
        head.backward(out_grad=out_grad, retain_graph=True)
        return True
    except Exception:
        return False

a = mx.nd.zeros((1,))
with autograd.record():
    b = 2 * a
if can_backward(b):
    b.backward()
    print("Did backward")
else:
    print("Can't backward")
Can't backward

Now if we attach the gradient

a = mx.nd.zeros((1,))
a.attach_grad()
with autograd.record():
    b = 2 * a
if can_backward(b):
    b.backward()
    print("Did backward")
else:
    print("Can't backward")
Did backward

And to get back to what you wanted to do initially with a list of heads and gradient_heads, you can filter them like this:

heads = [b]
gradient_heads = [None]
# Filter out the heads, gradient_heads pairs that are invalid
heads, gradient_heads = [list(t) for t in zip(*filter(lambda x: can_backward(x[0], x[1]), zip(heads, gradient_heads)))]
autograd.backward(heads, gradient_heads)
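
One caveat worth noting: if every pair gets filtered out, zip(*...) receives nothing and the unpacking above raises a ValueError, so a guarded variant of the same idea could look like this (just a sketch):

pairs = [(h, g) for h, g in zip(heads, gradient_heads) if can_backward(h, g)]
if pairs:
    heads, gradient_heads = [list(t) for t in zip(*pairs)]
    autograd.backward(heads, gradient_heads)
else:
    print("No valid heads left to backward")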

Thanks. This is a solution, albeit a pretty drastic one 🙂
