Lookup embedding results, accounting for invalid IDs and empty features.
tf.nn.safe_embedding_lookup_sparse(embedding_weights, sparse_ids, sparse_weights=None, combiner='mean', default_id=None, max_norm=None, name=None)
The partitioned embedding tensors in embedding_weights must all have the same shape except for the first dimension. The first dimension is allowed to vary because the vocabulary size is not necessarily a multiple of the number of shards.
Invalid IDs (< 0) are pruned from the input IDs and weights, as are any IDs with non-positive weight. For an entry with no features, the embedding vector for default_id is returned, or the 0-vector if default_id is not supplied.
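The pruning and default rules above can be sketched for a single entry in plain Python. This is an illustrative model of the documented behavior, not TensorFlow's implementation: the helper names prune_entry and lookup_entry are hypothetical, the table is a list of rows standing in for a single unsharded embedding tensor, and the "mean" combiner is modeled as a weighted average.

```python
def prune_entry(ids, weights=None):
    """Drop invalid ids (< 0) and ids whose weight is non-positive."""
    if weights is None:
        # No weights supplied: every id counts with weight 1.0.
        weights = [1.0] * len(ids)
    return [(i, w) for i, w in zip(ids, weights) if i >= 0 and w > 0]


def lookup_entry(table, ids, weights=None, default_id=None):
    """Model the lookup for one entry using a weighted mean (hypothetical helper)."""
    dim = len(table[0])
    kept = prune_entry(ids, weights)
    if not kept:
        # Entry with no features left: default_id's vector, or the 0-vector.
        if default_id is None:
            return [0.0] * dim
        return list(table[default_id])
    total_weight = sum(w for _, w in kept)
    return [sum(w * table[i][d] for i, w in kept) / total_weight
            for d in range(dim)]


table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]

print(lookup_entry(table, [0, 2]))            # mean of rows 0 and 2: [2.0, 2.0]
print(lookup_entry(table, [-1]))              # only an invalid id: [0.0, 0.0]
print(lookup_entry(table, [], default_id=3))  # empty entry: row 3, [4.0, 4.0]
```

Note that an entry whose ids are all pruned is treated the same way as an entry that was empty to begin with.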
The ids and weights may be multi-dimensional. Embeddings are always aggregated along the last dimension.
If len(embedding_weights) > 1, each element id of ids is partitioned between the elements of embedding_weights according to the "div" partition strategy, which assigns ids to partitions contiguously. For instance, 13 ids are split across 5 partitions as: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
If the number of ids is not evenly divisible by the number of partitions, each of the first (max_id + 1) % len(embedding_weights) partitions is assigned one more id.
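The "div" assignment can be reproduced with a few lines of plain Python. This is a minimal sketch of the rule as described here, not code taken from TensorFlow; div_partition and locate are hypothetical helper names, and num_ids stands for max_id + 1.

```python
def div_partition(num_ids, num_partitions):
    """Split ids 0..num_ids-1 into contiguous partitions ("div" strategy)."""
    base, extra = divmod(num_ids, num_partitions)
    partitions = []
    start = 0
    for p in range(num_partitions):
        # The first `extra` partitions each hold one more id.
        size = base + 1 if p < extra else base
        partitions.append(list(range(start, start + size)))
        start += size
    return partitions


def locate(id_, num_ids, num_partitions):
    """Return (partition index, row offset within that partition) for a valid id."""
    base, extra = divmod(num_ids, num_partitions)
    threshold = extra * (base + 1)  # number of ids held by the larger partitions
    if id_ < threshold:
        return id_ // (base + 1), id_ % (base + 1)
    return extra + (id_ - threshold) // base, (id_ - threshold) % base


print(div_partition(13, 5))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
print(locate(9, 13, 5))   # (3, 0): id 9 is the first row of partition 3
```

With 13 ids and 5 partitions, 13 % 5 = 3, so the first three partitions hold three ids each and the remaining two hold two.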
Args | |
---|---|
embedding_weights | A single tensor representing the complete embedding tensor, or a list of tensors all of the same shape except for the first dimension, representing sharded embedding tensors following the "div" partition strategy. |
sparse_ids | SparseTensor of shape [d_0, d_1, ..., d_n] containing the ids. d_0 is typically batch size. |
sparse_weights | SparseTensor of the same shape as sparse_ids, containing float weights corresponding to sparse_ids, or None if all weights are assumed to be 1.0. |
combiner | A string specifying how to combine embedding results for each entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the default. |
default_id | The id to use for an entry with no features. If not supplied, the 0-vector is returned for such entries. |
max_norm | If not None, each embedding whose l2 norm exceeds max_norm is clipped to max_norm before combining. |
name | A name for this operation (optional). |
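The three combiner values can be compared on one entry with a small plain-Python sketch. It follows the definitions given in the tf.nn.embedding_lookup_sparse documentation ("sum" is the weighted sum, "mean" divides it by the total weight, "sqrtn" divides it by the square root of the sum of squared weights); the function name combine is hypothetical and the code does not call TensorFlow.

```python
import math


def combine(vectors, weights, combiner="mean"):
    """Combine already-looked-up embedding vectors for one entry."""
    dim = len(vectors[0])
    weighted_sum = [sum(w * v[d] for v, w in zip(vectors, weights))
                    for d in range(dim)]
    if combiner == "sum":
        return weighted_sum
    if combiner == "mean":
        denom = sum(weights)
    elif combiner == "sqrtn":
        denom = math.sqrt(sum(w * w for w in weights))
    else:
        raise ValueError(f"unsupported combiner: {combiner!r}")
    return [x / denom for x in weighted_sum]


vectors = [[1.0, 1.0], [3.0, 3.0]]  # embeddings of the two ids in the entry
weights = [1.0, 3.0]

print(combine(vectors, weights, "sum"))    # [10.0, 10.0]
print(combine(vectors, weights, "mean"))   # [2.5, 2.5]
print(combine(vectors, weights, "sqrtn"))  # each component is 10 / sqrt(10)
```

When sparse_weights is None, every weight is 1.0, so "mean" reduces to a plain average and "sqrtn" divides the sum by the square root of the number of ids.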
Raises | |
---|---|
ValueError | If embedding_weights is empty. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/nn/safe_embedding_lookup_sparse