
How To Return Intermediate Gradients (for Non-leaf Nodes) In Pytorch?

My question concerns the syntax of PyTorch's register_hook:

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
x.register_hook(print)
y.register_hook(print)
z.backward()

Solution 1:

I think you can use those hooks to store the gradients in a global variable:

import torch

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y

# Each hook receives the gradient w.r.t. its tensor and appends it
x.register_hook(lambda grad: grads.append(grad))
y.register_hook(lambda grad: grads.append(grad))

z.backward()
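
The hooks fire in the order the backward pass reaches the tensors, so (assuming the snippet above runs as written) grads should hold dz/dy first and dz/dx second:

print(grads)  # [tensor([2.]), tensor([4.])] -- dz/dy = 2, dz/dx = 2 * 2x = 4 at x = 1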

But you most likely also need to remember which tensor each gradient was computed for. In that case, we can slightly extend the above by using a dict instead of a list:

import torch

grads = {}
x = torch.tensor([1., 2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad, parent):
    print(grad, parent)
    # Tensors hash by identity, so the parent tensor itself works as a dict key
    grads[parent] = grad.clone()

x.register_hook(lambda grad: store(grad, x))
y.register_hook(lambda grad: store(grad, y))

z.sum().backward()

Now you can, for example, access tensor y's gradient simply using grads[y].
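
As a quick sanity check (assuming the dict-based snippet above), the stored values should match the analytic gradients dz/dy = 2 and dz/dx = 2 * 2x:

print(grads[y])  # tensor([2., 2.])
print(grads[x])  # tensor([4., 8.])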
