Ops and objects returned from a model_fn and passed to TPUEstimator.
tf.compat.v1.estimator.tpu.TPUEstimatorSpec( mode, predictions=None, loss=None, train_op=None, eval_metrics=None, export_outputs=None, scaffold_fn=None, host_call=None, training_hooks=None, evaluation_hooks=None, prediction_hooks=None )
See EstimatorSpec for mode, predictions, loss, train_op, and export_outputs.
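The following is a minimal sketch of a TRAIN-mode model_fn that returns a TPUEstimatorSpec. It assumes a TensorFlow 2.x release that still ships `tf.compat.v1.estimator` (the Estimator API is deprecated and absent from recent releases). The linear classifier and the `params` keys `num_classes`, `learning_rate`, and `use_tpu` are hypothetical and supplied by the caller. On TPU, the optimizer is wrapped in `tf.compat.v1.tpu.CrossShardOptimizer` so that gradients are aggregated across shards.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TPUEstimatorSpec lives in the TF1-compatible estimator API


def model_fn(features, labels, mode, params):
    """Hypothetical TRAIN-mode model_fn: a linear softmax classifier.

    `features` is assumed to be a float32 [batch, dim] Tensor and `labels`
    an int32 [batch] Tensor; the params keys used here are illustrative.
    """
    num_classes = params["num_classes"]
    w = tf1.get_variable("w", [features.shape[-1], num_classes],
                         initializer=tf1.zeros_initializer())
    b = tf1.get_variable("b", [num_classes], initializer=tf1.zeros_initializer())
    logits = tf.matmul(features, w) + b
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

    optimizer = tf1.train.GradientDescentOptimizer(learning_rate=params["learning_rate"])
    if params.get("use_tpu", False):
        # Aggregate gradients across TPU shards before applying them.
        optimizer = tf1.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step=tf1.train.get_or_create_global_step())

    # Only mode, loss and train_op are needed for training.
    return tf1.estimator.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)
```

The remaining fields (eval_metrics, host_call, scaffold_fn, and so on) default to None and can be added as needed.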
For evaluation, eval_metrics is a tuple of metric_fn and tensors. metric_fn runs on the CPU to generate metrics, and tensors represents the Tensors transferred from the TPU system to the CPU host and passed to metric_fn. To be precise, TPU evaluation expects a slightly different signature from tf.estimator.Estimator: while EstimatorSpec.eval_metric_ops expects a dict, TPUEstimatorSpec.eval_metrics is a tuple of metric_fn and tensors.
The tensors can be a list of Tensors or a dict mapping names to Tensors. They usually specify the model logits, which are transferred back from the TPU system to the CPU host. All tensors must be batch-major, i.e., the batch size is the first dimension. Once all tensors from all shards are available on the CPU host, they are concatenated (on the CPU) and passed to metric_fn: as positional arguments if tensors is a list, or as keyword arguments if tensors is a dict. metric_fn takes the tensors and returns a dict mapping metric name strings to the result of calling a metric function, namely a (metric_tensor, update_op) tuple. See TPUEstimator for an MNIST example of how to specify eval_metrics.
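As a sketch of the two accepted forms of eval_metrics, the code below builds the same spec with a list of tensors (positional arguments) and with a dict (keyword arguments). It assumes an environment where `tf.compat.v1.estimator` is available. The helper `make_eval_spec` and the accuracy metric are illustrative choices, not part of the TPUEstimatorSpec API.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def metric_fn(labels, logits):
    """Runs on the CPU host with batch-major tensors concatenated across shards."""
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    # Each value is a (metric_tensor, update_op) tuple, as eval_metric_ops expects.
    return {"accuracy": tf1.metrics.accuracy(labels=labels, predictions=predictions)}


def make_eval_spec(mode, loss, labels, logits, as_dict=False):
    """Hypothetical helper: wraps metric_fn and its tensors into eval_metrics."""
    if as_dict:
        # Dict form: metric_fn is called with keyword arguments matched by name.
        eval_metrics = (metric_fn, {"labels": labels, "logits": logits})
    else:
        # List form: metric_fn is called with positional arguments in list order.
        eval_metrics = (metric_fn, [labels, logits])
    return tf1.estimator.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=eval_metrics)
```

With the list form, the order of the tensors must match the order of metric_fn's parameters. With the dict form, the keys must match the parameter names.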
scaffold_fn is a function running on the CPU to generate the Scaffold. This function should not capture any Tensors in model_fn.
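A minimal sketch of a scaffold_fn follows. It closes over only a checkpoint path string, never a Tensor created in model_fn. The factory `make_scaffold_fn` and its `init_checkpoint` argument are hypothetical. Warm-starting with `tf.compat.v1.train.init_from_checkpoint` is one common use, assumed here for illustration.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def make_scaffold_fn(init_checkpoint=None):
    """Hypothetical factory: returns a scaffold_fn that captures only a string path."""
    def scaffold_fn():
        if init_checkpoint:
            # Map every variable to the variable of the same name in the checkpoint.
            tf1.train.init_from_checkpoint(init_checkpoint, {"/": "/"})
        return tf1.train.Scaffold()
    return scaffold_fn
```

The result is passed as `scaffold_fn=make_scaffold_fn(path)` when constructing the TPUEstimatorSpec. TPUEstimator then calls it on the CPU to obtain the Scaffold.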
host_call is a tuple of a function and a list or dictionary of tensors to pass to that function; the function returns a list of Tensors. host_call currently works for train() and evaluate(). The Tensors returned by the function are executed on the CPU on every step, so there is communication overhead when sending tensors from the TPU to the CPU. To reduce the overhead, try reducing the size of the tensors. The tensors are concatenated along their major (batch) dimension, and so must be of rank >= 1. host_call is useful for writing summaries with tf.contrib.summary.create_file_writer.
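The sketch below builds a host_call that writes a loss summary on the CPU host. Scalars are reshaped to rank 1 because host_call tensors are concatenated along dimension 0. Since `tf.contrib` is unavailable in TensorFlow 2.x, the sketch substitutes the TF2 summary API (`tf.summary.create_file_writer`, `tf.summary.record_if`, `tf.summary.scalar`) together with `tf.compat.v1.summary.all_v2_summary_ops`. The helper `make_host_call` and its `model_dir` argument are hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def make_host_call(model_dir, global_step, loss):
    """Hypothetical helper: returns a (function, tensors) host_call tuple."""
    # Tensors are concatenated along dimension 0, so scalars become shape [1].
    gs_t = tf.reshape(global_step, [1])
    loss_t = tf.reshape(loss, [1])

    def host_call_fn(gs, loss_value):
        # After concatenation each argument has one entry per shard; use the first.
        step = gs[0]
        writer = tf.summary.create_file_writer(model_dir)
        with writer.as_default(), tf.summary.record_if(True):
            tf.summary.scalar("loss", loss_value[0], step=step)
            # Return the summary ops so they run on the host every step.
            return tf1.summary.all_v2_summary_ops()

    return (host_call_fn, [gs_t, loss_t])
```

Sending only two one-element tensors per step keeps the TPU-to-CPU transfer small, in line with the advice above.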
| Attributes |
|---|
| mode |
| predictions |
| loss |
| train_op |
| eval_metrics |
| export_outputs |
| scaffold_fn |
| host_call |
| training_hooks |
| evaluation_hooks |
| prediction_hooks |
as_estimator_spec
as_estimator_spec()
Creates an equivalent EstimatorSpec used by CPU train/eval.
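As a sketch of the conversion, the code below builds an EVAL-mode TPUEstimatorSpec and calls as_estimator_spec(). The eval_metrics tuple is turned into the eval_metric_ops dict that EstimatorSpec expects. The helper `to_cpu_spec` and the accuracy metric are hypothetical, and a TensorFlow release that still includes `tf.compat.v1.estimator` is assumed.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def metric_fn(labels, logits):
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return {"accuracy": tf1.metrics.accuracy(labels=labels, predictions=predictions)}


def to_cpu_spec(labels, logits):
    """Hypothetical helper: build an EVAL TPUEstimatorSpec and convert it for CPU."""
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    tpu_spec = tf1.estimator.tpu.TPUEstimatorSpec(
        mode=tf1.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metrics=(metric_fn, [labels, logits]))
    # metric_fn is called on the given tensors to produce eval_metric_ops.
    return tpu_spec.as_estimator_spec()
```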
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/compat/v1/estimator/tpu/TPUEstimatorSpec