EstimatorSpec
Defined in tensorflow/python/estimator/model_fn.py.
Ops and objects returned from a model_fn and passed to an Estimator.
EstimatorSpec fully defines the model to be run by an Estimator.
Properties:

- eval_metric_ops: Alias for field number 4.
- evaluation_hooks: Alias for field number 9.
- export_outputs: Alias for field number 5.
- loss: Alias for field number 2.
- mode: Alias for field number 0.
- prediction_hooks: Alias for field number 10.
- predictions: Alias for field number 1.
- scaffold: Alias for field number 8.
- train_op: Alias for field number 3.
- training_chief_hooks: Alias for field number 6.
- training_hooks: Alias for field number 7.
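"Alias for field number N" is the docstring that Python's `collections.namedtuple` generates for each field, which suggests that EstimatorSpec is a namedtuple whose fields can be read either by name or by position. The sketch below uses a hypothetical stand-in, `SpecSketch`, with the same field order. It is not `tf.estimator.EstimatorSpec` itself, which also validates its arguments in `__new__`.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the field numbers listed above.
# It is NOT tf.estimator.EstimatorSpec and performs no validation.
SpecSketch = namedtuple(
    "SpecSketch",
    [
        "mode",                  # field 0
        "predictions",           # field 1
        "loss",                  # field 2
        "train_op",              # field 3
        "eval_metric_ops",       # field 4
        "export_outputs",        # field 5
        "training_chief_hooks",  # field 6
        "training_hooks",        # field 7
        "scaffold",              # field 8
        "evaluation_hooks",      # field 9
        "prediction_hooks",      # field 10
    ],
)

# Placeholder values stand in for real Tensors, ops, and hooks.
spec = SpecSketch("train", None, 0.5, "train_op", None, None, (), (), None, (), ())

# Each named property is an alias for a positional field.
print(spec.loss == spec[2])  # True
print(spec.mode == spec[0])  # True
print(len(spec))             # 11
```

Because namedtuples are immutable, producing a modified copy goes through `spec._replace(...)` rather than attribute assignment.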
__new__

```python
@staticmethod
__new__(
    cls,
    mode,
    predictions=None,
    loss=None,
    train_op=None,
    eval_metric_ops=None,
    export_outputs=None,
    training_chief_hooks=None,
    training_hooks=None,
    scaffold=None,
    evaluation_hooks=None,
    prediction_hooks=None
)
```
Creates a validated EstimatorSpec instance.
Depending on the value of mode, different arguments are required, namely:

- mode == ModeKeys.TRAIN: required fields are loss and train_op.
- mode == ModeKeys.EVAL: required field is loss.
- mode == ModeKeys.PREDICT: required field is predictions.

model_fn can populate all arguments independent of mode. In this case, an Estimator ignores the arguments that do not apply to the current mode; for example, train_op is ignored in EVAL and PREDICT (infer) modes. Example:
```python
def my_model_fn(mode, features, labels):
    predictions = ...
    loss = ...
    train_op = ...
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op)
```
Alternatively, model_fn can just populate the arguments appropriate to the given mode. Example:
```python
def my_model_fn(mode, features, labels):
    if (mode == tf.estimator.ModeKeys.TRAIN or
            mode == tf.estimator.ModeKeys.EVAL):
        loss = ...
    else:
        loss = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = ...
    else:
        train_op = None
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = ...
    else:
        predictions = None
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op)
```
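Whichever style a model_fn uses, the per-mode requirements listed earlier must hold. The following is a minimal pure-Python sketch of that rule, not TensorFlow's actual validation code: `check_required`, `REQUIRED`, and the mode constants are hypothetical names, and the string values assume that tf.estimator.ModeKeys defines TRAIN, EVAL, and PREDICT as "train", "eval", and "infer".

```python
# Hypothetical constants mirroring tf.estimator.ModeKeys values.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"

# Required arguments per mode, as listed in the __new__ description.
REQUIRED = {
    TRAIN: ("loss", "train_op"),
    EVAL: ("loss",),
    PREDICT: ("predictions",),
}


def check_required(mode, **fields):
    """Hypothetical helper: raise ValueError if a field required for `mode` is None."""
    if mode not in REQUIRED:
        raise ValueError("Unknown mode: %r" % (mode,))
    missing = [name for name in REQUIRED[mode] if fields.get(name) is None]
    if missing:
        raise ValueError("Missing %s in %s mode." % (", ".join(missing), mode))


# As in the second example, only the mode-appropriate fields are populated.
check_required(TRAIN, loss=0.3, train_op="step")  # passes
check_required(PREDICT, predictions={"p": 1.0})   # passes
try:
    check_required(EVAL, predictions={"p": 1.0})  # loss is missing
except ValueError as err:
    print(err)  # Missing loss in eval mode.
```

Extra fields such as train_op in EVAL mode are simply not checked here, matching the note that an Estimator ignores arguments that do not apply to the current mode.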
Args:

- mode: A ModeKeys value. Specifies whether this is training, evaluation, or prediction.
- predictions: Predictions Tensor or dict of Tensor.
- loss: Training loss Tensor. Must be either a scalar or have shape [1].
- train_op: Op for the training step.
- eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically it is a pure computation based on variables). For example, it should not trigger the update_op or require any input fetching.
- export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:
  - name: An arbitrary name for this output.
  - output: An ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput.
  Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
- training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.
- training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.
- scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.
- evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
- prediction_hooks: Iterable of tf.train.SessionRunHook objects to run during prediction.

Returns:

A validated EstimatorSpec object.
Raises:

- ValueError: If validation fails.
- TypeError: If any of the arguments is not of the expected type.
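The eval_metric_ops contract described in the arguments (each value is a (metric_tensor, update_op) pair, and reading metric_tensor must not change state) can be illustrated without TensorFlow. In TensorFlow 1.x such a pair would typically come from a metric function such as tf.metrics.mean; the `StreamingMean` class below is a hypothetical pure-Python analogue that uses plain callables in place of the value Tensor and update op.

```python
class StreamingMean:
    """Hypothetical analogue of a streaming metric's (value, update_op) split."""

    def __init__(self):
        self.total = 0.0  # accumulated state, analogous to metric variables
        self.count = 0

    def update_op(self, values):
        # Mutates state: folds one batch of values into the running totals.
        self.total += sum(values)
        self.count += len(values)

    def metric_tensor(self):
        # Pure read of the state: calling it repeatedly changes nothing.
        return self.total / self.count if self.count else 0.0


mean = StreamingMean()
eval_metric_ops = {"mean_value": (mean.metric_tensor, mean.update_op)}

value_fn, update_fn = eval_metric_ops["mean_value"]
for batch in ([1.0, 2.0], [3.0, 6.0]):
    update_fn(batch)  # run once per evaluation batch
print(value_fn())     # 3.0
print(value_fn())     # 3.0 again: reading the value has no side effects
```

Keeping the read separate from the update is what lets an evaluation loop run the update once per batch and fetch the final value once at the end.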
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec