W3cubDocs

/PyTorch

enable_grad

class torch.enable_grad [source]

Context-manager that enables gradient calculation.

Enables gradient calculation, if it has been disabled via no_grad or set_grad_enabled.

This context manager is thread local; it will not affect computation in other threads.

Also functions as a decorator. (Make sure to instantiate with parenthesis.)

Example:

>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
...   with torch.enable_grad():
...     y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
>>> @torch.enable_grad()
... def doubler(x):
...     return x * 2
>>> with torch.no_grad():
...     z = doubler(x)
>>> z.requires_grad
True

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.7.0/generated/torch.enable_grad.html