class torch.set_grad_enabled(mode: bool)
Context-manager that sets gradient calculation to on or off.
set_grad_enabled will enable or disable grads based on its argument mode. It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation in other threads.
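The thread-local behavior can be illustrated with a minimal sketch. It assumes plain Python threads from the standard `threading` module; the `worker` function and the `observed` dictionary are hypothetical names introduced only for this illustration. A worker thread turns gradients off using the function form, and the main thread then checks that its own grad mode was not affected.

```python
import threading

import torch

# Make the main thread's starting state explicit for this sketch.
torch.set_grad_enabled(True)

x = torch.tensor([1.0], requires_grad=True)
observed = {}  # hypothetical container for what each thread sees


def worker():
    # Disable gradient calculation in this thread only (function form).
    torch.set_grad_enabled(False)
    observed["worker_enabled"] = torch.is_grad_enabled()
    observed["worker_requires_grad"] = (x * 2).requires_grad


thread = threading.Thread(target=worker)
thread.start()
thread.join()

# The worker's call did not change the grad mode of the main thread.
observed["main_enabled"] = torch.is_grad_enabled()
observed["main_requires_grad"] = (x * 2).requires_grad
print(observed)
```

In practice this means code that spawns its own threads has to set the grad mode inside each thread that needs it.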
mode (bool) – Flag whether to enable grad (True), or disable (False). This can be used to conditionally enable gradients.
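One common way to use the flag conditionally is to share a single forward-pass helper between training and evaluation. The following is a sketch under assumptions, not a prescribed pattern: `forward_pass`, the small `nn.Linear` model, and the random `batch` are hypothetical and exist only for illustration.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)    # hypothetical model
batch = torch.randn(4, 3)  # hypothetical input batch


def forward_pass(model, batch, is_train):
    # Gradients are tracked only when is_train is True; the previous
    # grad mode is restored when the with-block exits.
    with torch.set_grad_enabled(is_train):
        return model(batch)


train_out = forward_pass(model, batch, is_train=True)
eval_out = forward_pass(model, batch, is_train=False)

print(train_out.requires_grad, eval_out.requires_grad)
```

Passing a boolean here avoids duplicating the forward code in separate branches for the training and evaluation paths.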
Example:
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.7.0/generated/torch.set_grad_enabled.html