Note
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be advanced further than they would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout, for example) have deterministic output compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of the checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.
The stashing logic saves and restores the RNG state for the current device and for the devices of all CUDA Tensor arguments to the run_fn. However, the logic has no way to anticipate whether the user will move Tensors to a new device within the run_fn itself. Therefore, if you move Tensors to a new device ("new" meaning not belonging to the set of [current device + devices of Tensor arguments]) within run_fn, deterministic output compared to non-checkpointed passes is never guaranteed.
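For instance, a minimal sketch of opting out of the RNG bookkeeping (the dropout module and tensor shapes here are illustrative):

    import torch
    from torch.utils.checkpoint import checkpoint

    drop = torch.nn.Dropout(p=0.5)
    x = torch.randn(4, 8, requires_grad=True)

    # Default: the RNG state is stashed, so the dropout mask recomputed in
    # backward matches the one used in forward.
    out = checkpoint(drop, x)

    # Skip the RNG bookkeeping when determinism relative to a
    # non-checkpointed run is not required.
    out_fast = checkpoint(drop, x, preserve_rng_state=False)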
torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
[source]
Checkpoint a model or part of the model.
Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations and instead recomputes them in the backward pass. It can be applied to any part of a model.
Specifically, in the forward pass, function will run in torch.no_grad() manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the function parameter. In the backward pass, the saved inputs and function are retrieved, and the forward pass is computed on function again, this time tracking the intermediate activations; the gradients are then calculated using these activation values.
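A minimal sketch of this trade, assuming a small feed-forward model (the layer sizes are illustrative): the middle block's activations are recomputed during backward rather than stored during forward.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    head = nn.Linear(16, 32)
    body = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
    tail = nn.Linear(32, 4)

    x = torch.randn(8, 16, requires_grad=True)
    h = head(x)
    h = checkpoint(body, h)  # body runs without storing activations; its inputs are saved
    loss = tail(h).sum()
    loss.backward()          # body is re-run here to rebuild its activations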
Warning
Checkpointing doesn't work with torch.autograd.grad(), but only with torch.autograd.backward().
Warning
If the function invocation during backward does anything different than the one during forward, e.g., due to some global variable, the checkpointed version won't be equivalent, and unfortunately this can't be detected.
Warning
If the checkpointed segment contains tensors detached from the computational graph by detach() or torch.no_grad(), the backward pass will raise an error. This is because checkpoint makes all the outputs require gradients, which causes issues when a tensor is defined to have no gradient in the model. To circumvent this, detach the tensors outside of the checkpoint function.
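A hedged sketch of that workaround (the function and tensor names are illustrative): detach before the checkpointed call rather than inside it.

    import torch
    from torch.utils.checkpoint import checkpoint

    def segment(a, b):
        # b arrives already detached; nothing is cut from the graph in here
        return a * 2 + b

    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)
    b_detached = b.detach()             # detach outside the checkpointed function
    out = checkpoint(segment, a, b_detached)
    out.sum().backward()                # gradients flow to a; b is treated as a constant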
Parameters
    function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in LSTM, if the user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden.
    preserve_rng_state (bool, optional, default=True) – Omit stashing and restoring the RNG state during each checkpoint.
    args – tuple containing inputs to the function.
Returns
    Output of running function on *args
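To illustrate the tuple-handling contract above, a minimal LSTM-style sketch (run_fn and the shapes are illustrative): the inputs are passed positionally, so function must unpack them in the same order.

    import torch
    from torch.utils.checkpoint import checkpoint

    cell = torch.nn.LSTMCell(10, 20)

    def run_fn(activation, hx, cx):
        # the first input is the activation; the rest are the hidden/cell states
        return cell(activation, (hx, cx))

    activation = torch.randn(3, 10, requires_grad=True)
    hx = torch.randn(3, 20, requires_grad=True)
    cx = torch.randn(3, 20, requires_grad=True)

    h1, c1 = checkpoint(run_fn, activation, hx, cx)
    (h1.sum() + c1.sum()).backward()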
torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)
[source]
A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model into various segments and checkpoint each segment. All segments except the last will run in torch.no_grad() manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.
See checkpoint() for how checkpointing works.
Warning
Checkpointing doesn't work with torch.autograd.grad(), but only with torch.autograd.backward().
Parameters
    functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.
    segments – Number of chunks to create in the model.
    input – A Tensor that is input to functions.
    preserve_rng_state (bool, optional, default=True) – Omit stashing and restoring the RNG state during each checkpoint.
Returns
    Output of running functions sequentially on *inputs
Example:
    >>> model = nn.Sequential(...)
    >>> input_var = checkpoint_sequential(model, chunks, input_var)
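Expanding the snippet above into a runnable sketch (the layer sizes and segment count are illustrative):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    model = nn.Sequential(
        nn.Linear(64, 64), nn.ReLU(),
        nn.Linear(64, 64), nn.ReLU(),
        nn.Linear(64, 10),
    )
    input_var = torch.randn(16, 64, requires_grad=True)
    segments = 2  # split the model into 2 segments (the last runs normally)
    out = checkpoint_sequential(model, segments, input_var)
    out.sum().backward()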