class torch.nn.EmbeddingBag(num_embeddings: int, embedding_dim: int, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, _weight: Optional[torch.Tensor] = None, include_last_offset: bool = False)
Computes sums, means, or maxima of ‘bags’ of embeddings, without instantiating the intermediate embeddings.
For bags of constant length and no per_sample_weights, this class

with mode="sum" is equivalent to Embedding followed by torch.sum(dim=1),
with mode="mean" is equivalent to Embedding followed by torch.mean(dim=1),
with mode="max" is equivalent to Embedding followed by torch.max(dim=1).

Here the reduction runs over the bag dimension N of the (B, N, embedding_dim) tensor that Embedding returns for a (B, N) input. However, EmbeddingBag is much more time and memory efficient than using a chain of these operations.
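The equivalence above can be checked directly. This is a minimal sketch, assuming a randomly generated weight table (the seed and indices are arbitrary) shared by an Embedding and an EmbeddingBag through their from_pretrained constructors; it materializes the (B, N, embedding_dim) tensor that EmbeddingBag avoids, then compares each reduction.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weight = torch.randn(10, 3)  # arbitrary table: 10 embeddings of size 3

# Two bags of constant length 4, given as a 2D input of shape (B=2, N=4)
inp = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])

# A plain Embedding lookup materializes the intermediate (2, 4, 3) tensor
dense = nn.Embedding.from_pretrained(weight)(inp)

reference = {
    "sum": dense.sum(dim=1),
    "mean": dense.mean(dim=1),
    "max": dense.max(dim=1).values,  # torch.max(dim=...) returns (values, indices)
}

results = {}
for mode, expected in reference.items():
    bag = nn.EmbeddingBag.from_pretrained(weight, mode=mode)
    out = bag(inp)  # shape (2, 3), computed without the (2, 4, 3) intermediate
    results[mode] = torch.allclose(out, expected, atol=1e-6)

print(results)  # {'sum': True, 'mean': True, 'max': True}
```

A small absolute tolerance is used because the fused kernel may add the rows in a different order than torch.sum.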
EmbeddingBag also supports per-sample weights as an argument to the forward pass. These weights scale the output of the Embedding before the weighted reduction specified by mode is performed. If per_sample_weights is passed, the only supported mode is "sum", which computes a weighted sum according to per_sample_weights.
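As a sketch of what the weighted sum means, the following compares EmbeddingBag with mode="sum" and per_sample_weights against a manual computation that scales each looked-up row by its weight and sums per bag. The weight table and the per-sample weights are random and purely illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weight = torch.randn(10, 3)
bag = nn.EmbeddingBag.from_pretrained(weight, mode="sum")

inp = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # 1D: two bags of length 4
offsets = torch.tensor([0, 4])                # start index of each bag
psw = torch.rand(8)                           # one weight per index in inp

out = bag(inp, offsets, per_sample_weights=psw)

# Manual reference: scale each looked-up row by its weight, then sum per bag
scaled = weight[inp] * psw.unsqueeze(1)       # shape (8, 3)
expected = torch.stack([scaled[0:4].sum(dim=0), scaled[4:8].sum(dim=0)])

print(torch.allclose(out, expected, atol=1e-6))  # True
```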
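Parameters
num_embeddings (int) – size of the dictionary of embeddings
embedding_dim (int) – the size of each embedding vector
max_norm (float, optional) – If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm.
norm_type (float, optional) – The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq (boolean, optional) – if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. Note: this option is not supported when mode="max".
mode (string, optional) – "sum", "mean" or "max". Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weights into consideration. "mean" computes the average of the values in the bag, "max" computes the max value over each bag. Default: "mean"
sparse (bool, optional) – if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when mode="max".
include_last_offset (bool, optional) – if True, offsets has one additional element, where the last element is equal to the size of input. This matches the CSR format.

Variables
~EmbeddingBag.weight (Tensor) – the learnable weights of the module of shape (num_embeddings, embedding_dim) initialized from N(0, 1).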
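The sketch below exercises two of the constructor options above: include_last_offset and max_norm. It assumes, as we understand the implementation, that max_norm renormalization is applied in place to the rows of weight touched by a forward pass; the all-2.0 table is an arbitrary choice that makes every row exceed the norm limit.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weight = torch.randn(10, 3)
inp = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])

# include_last_offset=False: offsets hold only the start of each bag
plain = nn.EmbeddingBag.from_pretrained(weight, mode="sum")
out_plain = plain(inp, torch.tensor([0, 4]))

# include_last_offset=True: CSR-style offsets whose last entry equals len(inp)
csr = nn.EmbeddingBag.from_pretrained(weight, mode="sum", include_last_offset=True)
out_csr = csr(inp, torch.tensor([0, 4, 8]))

print(out_csr.shape, torch.allclose(out_plain, out_csr))  # torch.Size([2, 3]) True

# max_norm: rows looked up in a forward pass are rescaled to norm <= max_norm
big = torch.full((4, 3), 2.0)  # every row has L2 norm sqrt(12), about 3.46
clipped = nn.EmbeddingBag.from_pretrained(big.clone(), max_norm=1.0)
clipped(torch.tensor([[0, 1]]))     # one bag that touches rows 0 and 1 only
norms = clipped.weight.norm(dim=1)
print(norms)  # rows 0 and 1 are now about 1.0; rows 2 and 3 are unchanged
```

The table is cloned before being passed to from_pretrained so that the in-place renormalization does not modify the caller's tensor.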
Inputs: input (LongTensor), offsets (LongTensor, optional), and per_sample_weights (Tensor, optional)
If input is 2D of shape (B, N),
it will be treated as B bags (sequences) each of fixed length N, and this will return B values aggregated in a way depending on the mode. offsets is ignored and required to be None in this case.
If input is 1D of shape (N),
it will be treated as a concatenation of multiple bags (sequences). offsets is required to be a 1D tensor containing the starting index positions of each bag in input. Therefore, for offsets of shape (B), input will be viewed as having B bags. Empty bags (i.e., of length 0) return vectors filled with zeros.
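A minimal sketch of the two input layouts, using a freshly initialized module with an arbitrary seed: the same two bags are passed once as a 2D (B, N) tensor and once as a flat 1D tensor with offsets, and a third call uses variable-length bags that include an empty one (two equal consecutive offsets).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bag = nn.EmbeddingBag(10, 3, mode="mean")

# 2D input: B=2 bags of fixed length N=4; offsets must be None
out_2d = bag(torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]))

# The same bags as a flat 1D input plus the start index of each bag
out_1d = bag(torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), torch.tensor([0, 4]))
print(torch.allclose(out_2d, out_1d))  # True

# Variable-length bags: [1, 2], an empty bag (offsets 2 and 2 coincide), [4, 5]
out_var = bag(torch.tensor([1, 2, 4, 5]), torch.tensor([0, 2, 2]))
print(out_var.shape)  # torch.Size([3, 3])
print(out_var[1])     # the empty bag's row is all zeros
```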
per_sample_weights (Tensor, optional): a tensor of float / double weights, or None to indicate all weights should be taken to be 1. If specified, per_sample_weights must have exactly the same shape as input and is treated as having the same offsets, if those are not None. Only supported for mode='sum'.
Output shape: (B, embedding_dim)
Examples:
>>> import torch
>>> import torch.nn as nn
>>> # an EmbeddingBag module containing 10 tensors of size 3
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
>>> offsets = torch.LongTensor([0,4])
>>> # output values depend on the random initialization of weight
>>> embedding_sum(input, offsets)
tensor([[-0.8861, -5.4350, -0.0523],
        [ 1.1306, -2.5798, -1.0044]])
classmethod from_pretrained(embeddings: torch.Tensor, freeze: bool = True, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, include_last_offset: bool = False) → torch.nn.modules.sparse.EmbeddingBag
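Creates an EmbeddingBag instance from the given 2-dimensional FloatTensor.
Parameters
embeddings (Tensor) – FloatTensor containing weights for the EmbeddingBag. The first dimension is passed to EmbeddingBag as num_embeddings, the second as embedding_dim.
freeze (boolean, optional) – If True, the tensor does not get updated in the learning process. Equivalent to embeddingbag.weight.requires_grad = False. Default: True
max_norm (float, optional) – See module initialization documentation. Default: None
norm_type (float, optional) – See module initialization documentation. Default 2.
scale_grad_by_freq (boolean, optional) – See module initialization documentation. Default False.
mode (string, optional) – See module initialization documentation. Default: "mean"
sparse (bool, optional) – See module initialization documentation. Default: False.
include_last_offset (bool, optional) – See module initialization documentation. Default: False.
Examples: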
>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([[1, 0]])
>>> embeddingbag(input)
tensor([[2.5000, 3.7000, 4.6500]])
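To illustrate the freeze parameter, this sketch reuses the pretrained weights from the example above and builds one frozen and one trainable module. The weights are cloned for the trainable copy so the two modules do not share storage, which is an illustrative choice rather than a requirement.

```python
import torch
import torch.nn as nn

weight = torch.tensor([[1, 2.3, 3], [4, 5.1, 6.3]])

frozen = nn.EmbeddingBag.from_pretrained(weight)  # freeze=True by default
trainable = nn.EmbeddingBag.from_pretrained(weight.clone(), freeze=False)

print(frozen.weight.requires_grad, trainable.weight.requires_grad)  # False True

# Only the unfrozen table receives a gradient. With the default mode="mean",
# the bag [1, 0] averages both rows, so every entry's gradient is 0.5.
trainable(torch.tensor([[1, 0]])).sum().backward()
print(trainable.weight.grad)
```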
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.7.0/generated/torch.nn.EmbeddingBag.html