Unflatten
class torch.nn.modules.flatten.Unflatten(dim, unflattened_size)
Unflattens dimension dim of a tensor, expanding it to a desired shape. For use with Sequential.
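Here is a minimal sketch of the Sequential use case. It assumes image-like inputs of shape (N, 3, 4, 4). nn.Flatten collapses every dimension after the batch dimension into 48 features, a linear layer maps them, and nn.Unflatten restores the (3, 4, 4) layout. The sizes are arbitrary choices for illustration, not values from the original example.

```python
import torch
from torch import nn

# Flatten (default start_dim=1) collapses dims 1..-1; Unflatten restores them.
model = nn.Sequential(
    nn.Flatten(),                # (N, 3, 4, 4) -> (N, 48)
    nn.Linear(48, 48),           # (N, 48) -> (N, 48)
    nn.Unflatten(1, (3, 4, 4)),  # (N, 48) -> (N, 3, 4, 4)
)

x = torch.randn(8, 3, 4, 4)
y = model(x)
print(y.shape)  # torch.Size([8, 3, 4, 4])
```

Because 3 * 4 * 4 == 48, the unflattened size must match the linear layer's output width for the round trip to work.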
dim specifies the dimension of the input tensor to be unflattened. It can be either an int or a str when a Tensor or a NamedTensor is used, respectively.
unflattened_size is the new shape of the unflattened dimension of the tensor. It can be a tuple of ints, a list of ints, or a torch.Size for Tensor input, or a NamedShape (a tuple of (name, size) tuples) for NamedTensor input.
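For plain Tensor input, the three accepted forms of unflattened_size should be interchangeable. The sketch below uses arbitrarily chosen sizes, builds one module per form, and checks that the modules agree on both shape and values.

```python
import torch
from torch import nn

x = torch.randn(4, 30)

# The same target shape (2, 3, 5) expressed three ways; 2 * 3 * 5 == 30.
as_tuple = nn.Unflatten(1, (2, 3, 5))
as_list = nn.Unflatten(1, [2, 3, 5])
as_size = nn.Unflatten(1, torch.Size([2, 3, 5]))

outputs = [m(x) for m in (as_tuple, as_list, as_size)]
for out in outputs:
    print(out.shape)  # torch.Size([4, 2, 3, 5]) for each module

# The modules produce identical values, not just identical shapes.
same_values = all(torch.equal(outputs[0], o) for o in outputs[1:])
print(same_values)  # True
```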
- Shape:
- Input: $(*, S_{\text{dim}}, *)$, where $S_{\text{dim}}$ is the size at dimension dim and $*$ means any number of dimensions including none.
- Output: $(*, U_1, \ldots, U_n, *)$, where $U$ = unflattened_size and $\prod_{i=1}^{n} U_i = S_{\text{dim}}$.
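You can check the shape rule by hand. The sketch below uses a hypothetical helper, expected_output_shape, which is not part of PyTorch. It applies the rule for a non-negative integer dim. The example also unflattens a middle dimension so that the trailing dimension is visibly preserved. Finally, it shows that an unflattened_size whose product differs from the size at dim is rejected. The code assumes this rejection surfaces as a RuntimeError, which is what recent PyTorch releases raise.

```python
import math

import torch
from torch import nn


def expected_output_shape(in_shape, dim, sizes):
    """Hypothetical helper: apply the Shape rule by hand (non-negative int dim only)."""
    if math.prod(sizes) != in_shape[dim]:
        raise ValueError(f"prod({tuple(sizes)}) != {in_shape[dim]}")
    # Dimensions before dim, then the new sizes, then dimensions after dim.
    return torch.Size(tuple(in_shape[:dim]) + tuple(sizes) + tuple(in_shape[dim + 1:]))


x = torch.randn(2, 12, 7)   # size 12 at dim=1, with a trailing dimension of 7
m = nn.Unflatten(1, (3, 4))  # 3 * 4 == 12
print(m(x).shape)                                 # torch.Size([2, 3, 4, 7])
print(expected_output_shape(x.shape, 1, (3, 4)))  # torch.Size([2, 3, 4, 7])

# A size whose product (25) does not equal 12 fails in the forward pass.
mismatch_raised = False
try:
    nn.Unflatten(1, (5, 5))(x)
except RuntimeError as err:  # assumed exception type, see lead-in
    mismatch_raised = True
    print("RuntimeError:", err)
```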
- Parameters
- dim (Union[int, str]) – Dimension to be unflattened
- unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – New shape of the unflattened dimension
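A rough mental model helps here. This assumes a recent PyTorch release in which Tensor.unflatten accepts plain tensors. Under that assumption, the two parameters map directly onto Tensor.unflatten(dim, sizes), and for a contiguous input the result holds the same values as a reshape. The sketch compares all three and also prints the constructor arguments, which the module keeps as the dim and unflattened_size attributes.

```python
import torch
from torch import nn

x = torch.randn(2, 50)
m = nn.Unflatten(1, (2, 5, 5))

# The constructor arguments are stored on the module as given.
print(m.dim, m.unflattened_size)  # 1 (2, 5, 5)

# For a contiguous input, the values match a plain reshape and Tensor.unflatten.
matches_reshape = torch.equal(m(x), x.reshape(2, 2, 5, 5))
matches_unflatten = torch.equal(m(x), x.unflatten(1, (2, 5, 5)))
print(matches_reshape, matches_unflatten)  # True True
```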
Examples
>>> import torch
>>> from torch import nn
>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, (2, 5, 5))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, torch.Size([2, 5, 5]))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=("N", "features"))
>>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
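The NamedShape form does more than set sizes: it also assigns names to the new dimensions. This sketch extends the named-tensor example above by printing the resulting names. Named tensors are a prototype feature, so PyTorch may emit a warning when they are created.

```python
import torch
from torch import nn

# Named tensors are a prototype feature; a UserWarning may be printed here.
x = torch.randn(2, 50, names=("N", "features"))
unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5)))
out = unflatten(x)

print(out.names)  # ('N', 'C', 'H', 'W')
print(out.shape)  # torch.Size([2, 2, 5, 5])
```

The untouched batch dimension keeps its name "N", and "features" is replaced by the three names from the NamedShape.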