class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)
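TransformerDecoder is a stack of N decoder layers.
Parameters:
decoder_layer: an instance of the TransformerDecoderLayer class (required).
num_layers: the number of sub-decoder-layers in the decoder (required).
norm: the layer normalization component (optional).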
Example:
>>> import torch
>>> import torch.nn as nn
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
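The "stack of N decoder layers" can be inspected directly. The sketch below is an illustration, not part of the original docs: it passes an nn.LayerNorm as the optional norm, checks that the template layer was cloned num_layers times, and confirms that the output keeps the (T, N, E) shape of tgt. It assumes the clones are stored in the `layers` attribute and that `linear1` is the first feed-forward projection of TransformerDecoderLayer. These are implementation attributes rather than documented parameters, so treat them as assumptions about the installed version.

```python
import torch
import torch.nn as nn

# Template layer; TransformerDecoder clones it num_layers times.
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
# Optional final normalization applied to the output of the last layer.
transformer_decoder = nn.TransformerDecoder(
    decoder_layer, num_layers=6, norm=nn.LayerNorm(512)
)

print(len(transformer_decoder.layers))  # 6

# Clones are separate Parameter objects, but start with identical values
# because each one is a copy of the same template layer.
w0 = transformer_decoder.layers[0].linear1.weight
w1 = transformer_decoder.layers[1].linear1.weight
print(w0 is w1, torch.equal(w0, w1))  # False True

memory = torch.rand(10, 32, 512)  # (S, N, E): source length, batch, features
tgt = torch.rand(20, 32, 512)     # (T, N, E): target length, batch, features
out = transformer_decoder(tgt, memory)
print(out.shape)  # torch.Size([20, 32, 512])
```

Because the clones start from identical weights, you may want to re-initialize them after construction if distinct starting points matter. nn.Transformer, for example, resets its own parameters in its constructor.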
forward(tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None) → torch.Tensor
Pass the inputs (and masks) through each decoder layer in turn, then apply the optional final norm.
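"In turn" means each layer consumes the previous layer's output together with the same memory. The minimal sketch below rebuilds that loop by hand and compares it with the module's own forward. It is an illustrative re-implementation, not the library source. It assumes the `layers` and `norm` attributes hold the cloned layers and the optional final norm. The model is put in eval mode so dropout does not make the two paths differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
decoder_layer = nn.TransformerDecoderLayer(d_model=64, nhead=4)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=3, norm=nn.LayerNorm(64))
decoder.eval()  # disable dropout so both paths are deterministic

memory = torch.rand(7, 2, 64)  # (S, N, E)
tgt = torch.rand(5, 2, 64)     # (T, N, E)

with torch.no_grad():
    out = decoder(tgt, memory)

    # Hand-written loop: every layer sees the running output and the same memory.
    x = tgt
    for layer in decoder.layers:
        x = layer(x, memory)
    # The optional norm is applied once, after the last layer.
    if decoder.norm is not None:
        x = decoder.norm(x)

print(torch.allclose(out, x, atol=1e-6))
```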
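Parameters:
tgt: the sequence to the decoder (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
Shape: see the docs of the Transformer class.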
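The mask arguments are easiest to understand with concrete shapes. The sketch below is a hedged illustration, not taken from the original page. It builds a causal tgt_mask of shape (T, T) and key padding masks of shape (N, T) and (N, S) as boolean tensors, where True marks a position that may not be attended to. It then checks that changing the last target step leaves earlier outputs unchanged. The sizes T, S, N, E and the helper `run` are hypothetical names chosen for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, S, N, E = 5, 7, 2, 64  # target length, memory length, batch size, model dim

decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model=E, nhead=4), num_layers=2)
decoder.eval()  # disable dropout for a deterministic comparison

tgt = torch.rand(T, N, E)
memory = torch.rand(S, N, E)

# Causal mask (T, T): True above the diagonal, so position i cannot attend to j > i.
tgt_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()

# Key padding masks: True marks padded positions. The second sequence in the
# batch has its last 2 target steps and last 3 memory steps padded.
tgt_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)
tgt_key_padding_mask[1, -2:] = True
memory_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
memory_key_padding_mask[1, -3:] = True

def run(t):
    # Hypothetical helper: run the decoder on target `t` with all masks applied.
    with torch.no_grad():
        return decoder(t, memory, tgt_mask=tgt_mask,
                       tgt_key_padding_mask=tgt_key_padding_mask,
                       memory_key_padding_mask=memory_key_padding_mask)

out = run(tgt)

# Perturb only the last target step; under the causal mask the earlier
# output positions should not change.
tgt_changed = tgt.clone()
tgt_changed[-1] = torch.rand(N, E)
out_changed = run(tgt_changed)
print(torch.allclose(out[:-1], out_changed[:-1], atol=1e-6))
```

Every query row still has at least one unmasked key here. A row that is fully masked (for example, padding position 0 of a target under a causal mask) can produce NaN outputs, so padding is usually placed at the end of each sequence.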
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.7.0/generated/torch.nn.TransformerDecoder.html