class torch.nn.EmbeddingBag(num_embeddings: int, embedding_dim: int, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, _weight: Optional[torch.Tensor] = None, include_last_offset: bool = False)
Computes sums or means of ‘bags’ of embeddings, without instantiating the intermediate embeddings.
For bags of constant length, no per_sample_weights, and a 2D input of shape (B, N), this class
with mode="sum" is equivalent to Embedding followed by torch.sum(dim=1),
with mode="mean" is equivalent to Embedding followed by torch.mean(dim=1),
with mode="max" is equivalent to Embedding followed by torch.max(dim=1) (the values it returns, not the indices).
However, EmbeddingBag is much more time and memory efficient than using a chain of these operations.
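The equivalence described above can be checked numerically. The following is a minimal sketch, not taken from the original page: the weight tensor, the indices, and the names `reducers` and `matches` are illustrative assumptions. It compares each mode against an Embedding lookup reduced over the bag dimension (dim=1 for a (B, N) input).

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative weights: 10 embeddings of size 3 (assumed values, not from the docs).
weight = torch.randn(10, 3)

# B=2 bags, each of fixed length N=4.
indices = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])

# A plain Embedding sharing the same weights, used as the reference.
emb = nn.Embedding.from_pretrained(weight)

# Reduction over the bag dimension for each mode.
reducers = {
    "sum": lambda t: t.sum(dim=1),
    "mean": lambda t: t.mean(dim=1),
    "max": lambda t: t.max(dim=1).values,  # torch.max also returns indices
}

matches = {}
for mode, reduce in reducers.items():
    bag = nn.EmbeddingBag.from_pretrained(weight, mode=mode)
    bag_out = bag(indices)               # shape (2, 3)
    chain_out = reduce(emb(indices))     # (2, 4, 3) reduced to (2, 3)
    matches[mode] = torch.allclose(bag_out, chain_out, atol=1e-6)

print(matches)
```

The chained version materializes the intermediate (B, N, embedding_dim) tensor, which is the allocation EmbeddingBag avoids.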
EmbeddingBag also supports per-sample weights as an argument to the forward pass. This scales the output of the Embedding before performing a weighted reduction as specified by mode. If per_sample_weights is passed, the only supported mode is "sum", which computes a weighted sum according to per_sample_weights.
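As a rough illustration of the weighted reduction, the sketch below compares the output of the forward pass with per_sample_weights against a manual computation. The weights, indices, and variable names are assumptions chosen for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative weights (assumed values, not from the docs).
weight = torch.randn(10, 3)
bag = nn.EmbeddingBag.from_pretrained(weight, mode="sum")

# Two bags of three indices each, with one weight per index.
indices = torch.tensor([[1, 2, 4], [4, 3, 9]])
per_sample_weights = torch.tensor([[0.5, 1.0, 2.0], [1.0, 0.0, 3.0]])

weighted = bag(indices, per_sample_weights=per_sample_weights)

# Reference: scale every looked-up row by its weight, then sum over the bag dimension.
reference = (weight[indices] * per_sample_weights.unsqueeze(-1)).sum(dim=1)

weighted_matches = torch.allclose(weighted, reference, atol=1e-6)
print(weighted_matches)
```

A weight of 0.0 removes that index from the bag's sum, which can be a convenient way to mask entries without changing the input shape.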
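Parameters:
num_embeddings (int) – size of the dictionary of embeddings
embedding_dim (int) – the size of each embedding vector
max_norm (float, optional) – If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm.
norm_type (float, optional) – The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq (bool, optional) – If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. Note: this option is not supported when mode="max".
mode (string, optional) – "sum", "mean" or "max". Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weights into consideration. "mean" computes the average of the values in the bag, "max" computes the max value over each bag. Default: "mean"
sparse (bool, optional) – If True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when mode="max".
include_last_offset (bool, optional) – If True, offsets has one additional element, where the last element is equivalent to the size of input. This matches the CSR format.
Variables:
weight (Tensor) – the learnable weights of the module of shape (num_embeddings, embedding_dim) initialized from N(0, 1).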
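Two of the options above are easy to observe directly. The sketch below is an illustration under assumed inputs (the weights, indices, and variable names are not from the original page). It shows that a CSR-style offsets tensor with include_last_offset=True yields the same bags as the default layout, and that sparse=True produces a sparse gradient for weight.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative weights and a flat list of 8 indices (assumed values).
weight = torch.randn(10, 3)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])

# Default layout: one starting offset per bag.
plain = nn.EmbeddingBag.from_pretrained(weight, mode="sum")
out_plain = plain(indices, torch.tensor([0, 4]))

# CSR-style layout: the extra final offset equals the number of indices.
csr = nn.EmbeddingBag.from_pretrained(weight, mode="sum", include_last_offset=True)
out_csr = csr(indices, torch.tensor([0, 4, 8]))

same_output = torch.allclose(out_plain, out_csr)

# With sparse=True the gradient w.r.t. weight is a sparse tensor.
sparse_bag = nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
sparse_bag(indices, torch.tensor([0, 4])).sum().backward()
grad_is_sparse = sparse_bag.weight.grad.is_sparse

print(same_output, grad_is_sparse)
```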
Inputs: input (LongTensor), offsets (LongTensor, optional), and per_sample_weights (Tensor, optional)
If input is 2D of shape (B, N), it will be treated as B bags (sequences) each of fixed length N, and this will return B values aggregated in a way depending on the mode. offsets must be None in this case.
If input is 1D of shape (N), it will be treated as a concatenation of multiple bags (sequences). offsets is required to be a 1D tensor containing the starting index positions of each bag in input. Therefore, for offsets of shape (B), input will be viewed as having B bags. Empty bags (i.e., having 0-length) will have returned vectors filled by zeros.
per_sample_weights (Tensor, optional): a tensor of float or double weights, or None to indicate all weights should be taken to be 1. If specified, per_sample_weights must have exactly the same shape as input and is treated as having the same offsets, if those are not None. Only supported for mode='sum'.
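To make the 1D input convention concrete, the following sketch splits a flat index tensor into bags by hand and compares the result with the module output. The data and the helper name `bounds` are assumptions for illustration; the second bag is deliberately empty to show the zero-filled row.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative weights (assumed values, not from the docs).
weight = torch.randn(10, 3)
bag = nn.EmbeddingBag.from_pretrained(weight, mode="sum")

# Six indices split into four bags: [0:2], [2:2] (empty), [2:5], [5:6].
indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 2, 2, 5])

out = bag(indices, offsets)  # shape (4, 3)

# Manual reference: slice the flat indices at each offset and sum the rows.
bounds = offsets.tolist() + [indices.numel()]
reference = torch.stack([
    weight[indices[start:end]].sum(dim=0)
    for start, end in zip(bounds[:-1], bounds[1:])
])

offsets_match = torch.allclose(out, reference, atol=1e-6)
empty_bag_is_zero = bool(torch.all(out[1] == 0))
print(offsets_match, empty_bag_is_zero)
```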
Output shape: (B, embedding_dim)
Examples:
>>> import torch
>>> from torch import nn
>>> # an EmbeddingBag module containing 10 tensors of size 3
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
>>> offsets = torch.LongTensor([0,4])
>>> # the printed values depend on the random weight initialization
>>> embedding_sum(input, offsets)
tensor([[-0.8861, -5.4350, -0.0523],
        [ 1.1306, -2.5798, -1.0044]])
classmethod from_pretrained(embeddings: torch.Tensor, freeze: bool = True, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, include_last_offset: bool = False) → torch.nn.modules.sparse.EmbeddingBag
Creates an EmbeddingBag instance from a given 2-dimensional FloatTensor.
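Parameters:
embeddings (Tensor) – FloatTensor containing weights for the EmbeddingBag. First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
freeze (bool, optional) – If True, the tensor does not get updated in the learning process. Equivalent to embeddingbag.weight.requires_grad = False. Default: True
max_norm (float, optional) – See module initialization documentation. Default: None
norm_type (float, optional) – See module initialization documentation. Default 2.
scale_grad_by_freq (bool, optional) – See module initialization documentation. Default False.
mode (string, optional) – See module initialization documentation. Default: "mean"
sparse (bool, optional) – See module initialization documentation. Default: False.
include_last_offset (bool, optional) – See module initialization documentation. Default: False.
Examples: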
>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
>>> # One bag holding indices 1 and 0; the default mode='mean' averages the two rows
>>> input = torch.LongTensor([[1, 0]])
>>> embeddingbag(input)
tensor([[ 2.5000,  3.7000,  4.6500]])
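The effect of freeze can be inspected on the resulting module. This is a small sketch reusing the weights from the example above; the variable names `frozen` and `trainable` are illustrative.

```python
import torch
from torch import nn

# Same pretrained weights as in the example above.
weight = torch.tensor([[1.0, 2.3, 3.0], [4.0, 5.1, 6.3]])

# freeze=True is the default, so the weight is excluded from gradient updates.
frozen = nn.EmbeddingBag.from_pretrained(weight)
trainable = nn.EmbeddingBag.from_pretrained(weight, freeze=False)

frozen_requires_grad = frozen.weight.requires_grad        # False
trainable_requires_grad = trainable.weight.requires_grad  # True

# A single bag with indices 1 and 0: mode='mean' averages both rows.
out = frozen(torch.tensor([[1, 0]]))
expected = weight.mean(dim=0, keepdim=True)
mean_matches = torch.allclose(out, expected)

print(frozen_requires_grad, trainable_requires_grad, mean_matches)
```

Passing freeze=False is the usual choice when the pretrained vectors should be fine-tuned along with the rest of the model.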
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.7.0/generated/torch.nn.EmbeddingBag.html