Apply standard lookup ops with tf.tpu.experimental.embedding configs.
tf.tpu.experimental.embedding.serving_embedding_lookup(inputs, weights, tables, feature_config)
This utility function allows the tf.tpu.experimental.embedding config objects to be used with standard lookup functions. It is useful when exporting a model that uses tf.tpu.experimental.embedding.TPUEmbedding for serving on CPU. In particular, tf.tpu.experimental.embedding.TPUEmbedding only supports lookups on TPUs and should not be part of your serving graph.
Note that TPU-specific options (such as max_sequence_length) in the configuration objects will be ignored.
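To make "standard lookup" concrete, the following is a minimal pure-Python sketch of what a combiner-based embedding lookup computes for a single feature. It is an illustration only, not the TensorFlow implementation: `lookup_with_combiner`, the toy `table`, and the list-of-lists input layout are hypothetical, and the `sum` and `mean` combiners are assumed to follow the usual weighted definitions.

```python
from typing import List, Optional, Sequence


def lookup_with_combiner(
    table: Sequence[Sequence[float]],
    ids: Sequence[Sequence[int]],
    weights: Optional[Sequence[Sequence[float]]] = None,
    combiner: str = "mean",
) -> List[List[float]]:
    """Combine the embedding rows selected by each example's ids.

    Hypothetical helper for illustration. Assumes every example has at
    least one id and that its weights sum to a nonzero value.
    """
    if combiner not in ("sum", "mean"):
        raise ValueError(f"unsupported combiner: {combiner}")
    dim = len(table[0])
    result = []
    for row, example_ids in enumerate(ids):
        # Missing weights are treated as 1.0 for every id.
        if weights is not None:
            example_weights = list(weights[row])
        else:
            example_weights = [1.0] * len(example_ids)
        combined = [0.0] * dim
        for idx, w in zip(example_ids, example_weights):
            for d in range(dim):
                combined[d] += w * table[idx][d]
        if combiner == "mean":
            # Weighted mean: divide the weighted sum by the total weight.
            total = sum(example_weights)
            combined = [v / total for v in combined]
        result.append(combined)
    return result


# Toy table with three rows of dimension 2, and two examples.
table = [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]
embedded = lookup_with_combiner(table, ids=[[1, 2], [2]])
```

Each example yields one vector of the table's dimension regardless of how many ids it contains, which is why per-feature options that only shape TPU-side outputs are not needed for this kind of lookup.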
In the following example, we take a trained model (see the documentation for tf.tpu.experimental.embedding.TPUEmbedding for the context) and create a saved model with a serving function that performs the embedding lookup and passes the results to your model:
    import tensorflow as tf

    # Rebuild the model and the TPU embedding object, then restore trained weights.
    model = model_fn(...)
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(
        feature_config=feature_config,
        batch_size=1024,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))
    checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
    checkpoint.restore(...)

    @tf.function(input_signature=[{'feature_one': tf.TensorSpec(...),
                                   'feature_two': tf.TensorSpec(...),
                                   'feature_three': tf.TensorSpec(...)}])
    def serve_tensors(embedding_features):
      # Perform the lookup with standard ops; no weights are passed.
      embedded_features = (
          tf.tpu.experimental.embedding.serving_embedding_lookup(
              embedding_features, None, embedding.embedding_tables,
              feature_config))
      return model(embedded_features)

    # Attach the embedding object so its variables are saved with the model.
    model.embedding_api = embedding
    tf.saved_model.save(model,
                        export_dir=...,
                        signatures={'serving_default': serve_tensors})
Note: It's important to assign the embedding API object to a member of your model, as tf.saved_model.save only supports saving variables from one Trackable object. Since the model's weights are in model and the embedding tables are managed by embedding, we assign embedding to an attribute of model so that tf.saved_model.save can find the embedding variables.
Note: The same serve_tensors function and tf.saved_model.save call will work directly from training.
Args | |
---|---|
inputs | a nested structure of Tensors, SparseTensors or RaggedTensors. |
weights | a nested structure of Tensors, SparseTensors or RaggedTensors, or None for no weights. If not None, the structure must match that of inputs, but individual entries are allowed to be None. |
tables | a dict mapping TableConfig objects to Variables. |
feature_config | a nested structure of FeatureConfig objects with the same structure as inputs. |
Returns | |
---|---|
A nested structure of Tensors with the same structure as inputs. |
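
The structure requirements in the Args and Returns tables can be sketched with plain Python dicts standing in for nested structures. This is an assumption-laden illustration rather than library code: `apply_per_feature` and `fake_lookup` are hypothetical names, the string configs are placeholders, and only a flat dict is handled here, whereas the real function accepts nested structures.

```python
from typing import Any, Callable, Dict, Optional


def apply_per_feature(
    inputs: Dict[str, Any],
    weights: Optional[Dict[str, Any]],
    feature_config: Dict[str, Any],
    lookup: Callable[[Any, Any, Any], Any],
) -> Dict[str, Any]:
    """Pair each input with its weight and config, keeping the input keys."""
    if set(inputs) != set(feature_config):
        raise ValueError("feature_config must have the same structure as inputs")
    if weights is not None and set(weights) != set(inputs):
        raise ValueError("weights must have the same structure as inputs")
    outputs = {}
    for name, value in inputs.items():
        # weights=None means no weights at all; single entries may also be None.
        weight = None if weights is None else weights[name]
        outputs[name] = lookup(value, weight, feature_config[name])
    return outputs


def fake_lookup(value, weight, config):
    # Stand-in for a real lookup: records the arguments it was called with.
    return (config, value, weight)


inputs = {"feature_one": [1, 2], "feature_two": [3]}
weights = {"feature_one": [0.5, 0.5], "feature_two": None}
configs = {"feature_one": "config_a", "feature_two": "config_b"}
outputs = apply_per_feature(inputs, weights, configs, fake_lookup)
```

The result has the same keys as `inputs`, mirroring the documented guarantee that the returned structure matches the structure of inputs.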
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/tpu/experimental/embedding/serving_embedding_lookup