tf.estimator.experimental.stop_if_higher_hook

Creates a hook that stops training if the given metric is higher than the threshold.

Usage example:

import tensorflow as tf

estimator = ...
# Hook to stop training if accuracy becomes higher than 0.9.
hook = tf.estimator.experimental.stop_if_higher_hook(estimator, "accuracy", 0.9)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)

Caveat: The current implementation supports early stopping of both training and evaluation in local mode. In distributed mode, training can be stopped, but evaluation (when it runs as a separate job) will wait indefinitely for new model checkpoints to evaluate, so you will need other means to detect and stop it. Early stopping of evaluation in distributed mode requires changes to the train_and_evaluate API and will be addressed in a future revision.
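One possible way to handle the separate evaluator job mentioned in the caveat is a watchdog that gives up once no new checkpoint has appeared for a set time. The plain-Python sketch below illustrates the idea only. The names wait_for_new_checkpoint, get_latest, timeout_secs, and poll_secs are hypothetical and not part of TensorFlow; get_latest stands in for whatever call reports the newest checkpoint path in your setup.

```python
import time


def wait_for_new_checkpoint(get_latest, last_seen, timeout_secs,
                            poll_secs=1.0, clock=time.monotonic,
                            sleep=time.sleep):
    """Return a checkpoint path different from last_seen, or None on timeout.

    Hypothetical helper: get_latest is any zero-argument callable that
    returns the newest checkpoint path (or None if there is none yet).
    """
    deadline = clock() + timeout_secs
    while True:
        latest = get_latest()
        if latest is not None and latest != last_seen:
            return latest  # a new checkpoint appeared, so evaluate it
        if clock() >= deadline:
            return None  # nothing new in time; the evaluator can exit
        sleep(poll_secs)
```

An evaluator loop could call this helper after each evaluation and exit when it returns None. The timeout should comfortably exceed the trainer's checkpoint interval, otherwise the evaluator may quit while training is still running.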

Args
estimator: A tf.estimator.Estimator instance.
metric_name: str, metric to track. "loss", "accuracy", etc.
threshold: Numeric threshold for the given metric.
eval_dir: If set, directory containing summary files with eval metrics. By default, estimator.eval_dir() will be used.
min_steps: int, stop is never requested if global step is less than this value. Defaults to 0.
run_every_secs: If specified, calls should_stop_fn at an interval of run_every_secs seconds. Defaults to 60 seconds. Either this or run_every_steps must be set.
run_every_steps: If specified, calls should_stop_fn every run_every_steps steps. Either this or run_every_secs must be set.
Returns
An early-stopping hook of type SessionRunHook that periodically checks whether the given metric is higher than the specified threshold and initiates early stopping if it is.
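To make the check performed by the returned hook more concrete, the sketch below mimics the decision in plain Python: it scans evaluation results by global step and requests a stop once the metric exceeds the threshold. This is a simplified illustration, not TensorFlow's code. The function should_stop_if_higher and the eval_results dictionary (global step mapped to metric values) are assumptions made for the example.

```python
def should_stop_if_higher(eval_results, metric_name, threshold, min_steps=0):
    """Return True if any eligible evaluation has metric > threshold.

    eval_results is a hypothetical dict: {global_step: {metric_name: value}}.
    """
    for step in sorted(eval_results):
        if step < min_steps:
            continue  # a stop is never requested before min_steps
        if eval_results[step][metric_name] > threshold:
            return True  # strictly higher than the threshold
    return False


results = {
    100: {"accuracy": 0.95},
    200: {"accuracy": 0.85},
    300: {"accuracy": 0.92},
}
print(should_stop_if_higher(results, "accuracy", 0.9))                 # True
print(should_stop_if_higher(results, "accuracy", 0.9, min_steps=400))  # False
```

The comparison is strict, so a metric exactly equal to the threshold does not trigger a stop in this sketch.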

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/estimator/experimental/stop_if_higher_hook