Interface for listeners that take action before or after checkpoint save.
CheckpointSaverListener triggers only in steps when CheckpointSaverHook is triggered, and provides callbacks at the following points:
- before using the session
- before each call to Saver.save()
- after each call to Saver.save()
- at the end of session
To use a listener, implement a class and pass the listener to a CheckpointSaverHook, as in this example:
    class ExampleCheckpointSaverListener(CheckpointSaverListener):
      def begin(self):
        # You can add ops to the graph here.
        print('Starting the session.')
        self.your_tensor = ...

      def before_save(self, session, global_step_value):
        print('About to write a checkpoint')

      def after_save(self, session, global_step_value):
        print('Done writing checkpoint.')
        if decided_to_stop_training():
          return True

      def end(self, session, global_step_value):
        print('Done with the session.')

    ...
    listener = ExampleCheckpointSaverListener()
    saver_hook = tf.estimator.CheckpointSaverHook(
        checkpoint_dir, listeners=[listener])
    with tf.compat.v1.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
      ...
A CheckpointSaverListener may simply take some action after every checkpoint save. The listener can also use its own schedule to act less frequently, e.g. based on global_step_value. In this case, implementors should implement the end() method to handle actions related to the last checkpoint save. However, the listener should not act twice if after_save() already handled this last checkpoint save.
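The scheduling rule above can be sketched as follows. This is an illustrative sketch only, not part of the TensorFlow API: the class name EveryNStepsListener, the every_n_steps parameter, and the acted_steps list are hypothetical. To keep the sketch runnable without TensorFlow, it is a plain Python class exposing the same four method names; real code would subclass tf.estimator.CheckpointSaverListener, and session would be a live session rather than None.

```python
class EveryNStepsListener:
    """Hypothetical listener that acts at most once every `every_n_steps` steps."""

    def __init__(self, every_n_steps):
        self._every_n_steps = every_n_steps
        self._last_acted_step = None  # global step of the last save we acted on
        self.acted_steps = []         # record of steps acted on, for illustration

    def _act(self, global_step_value):
        # Placeholder for the real work (e.g. running an evaluation).
        self.acted_steps.append(global_step_value)
        self._last_acted_step = global_step_value

    def begin(self):
        pass

    def before_save(self, session, global_step_value):
        pass

    def after_save(self, session, global_step_value):
        # Act on the first save, then only when enough steps have passed.
        due = (self._last_acted_step is None or
               global_step_value - self._last_acted_step >= self._every_n_steps)
        if due:
            self._act(global_step_value)

    def end(self, session, global_step_value):
        # Handle the last checkpoint, unless after_save() already did.
        if self._last_acted_step != global_step_value:
            self._act(global_step_value)


# Simulate the calls a saver hook would make, with saves at four steps.
listener = EveryNStepsListener(every_n_steps=1000)
listener.begin()
for step in (100, 600, 1100, 1600):
    listener.before_save(None, step)
    listener.after_save(None, step)
listener.end(None, 1600)
print(listener.acted_steps)
```

Tracking the last step acted on is what lets end() cover the final checkpoint without repeating work that after_save() has already done for the same step.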
A CheckpointSaverListener can request that training be stopped by returning True from after_save. Note that in a replicated distributed training setting, only the chief should use this behavior. Otherwise each worker will do its own evaluation, which may waste resources.
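A minimal sketch of the stop request, under stated assumptions: the class name StopWhenListener, the should_stop callable, and the is_chief flag are all hypothetical and are not TensorFlow APIs. As a plain Python stand-in it runs without TensorFlow; real code would subclass tf.estimator.CheckpointSaverListener and would typically be registered only on the chief.

```python
class StopWhenListener:
    """Hypothetical listener that asks for training to stop from after_save()."""

    def __init__(self, should_stop, is_chief):
        self._should_stop = should_stop  # callable: global_step_value -> bool
        self._is_chief = is_chief        # assumed to be supplied by the caller

    def begin(self):
        pass

    def before_save(self, session, global_step_value):
        pass

    def after_save(self, session, global_step_value):
        # Only the chief evaluates the stop condition; other workers skip it.
        if self._is_chief and self._should_stop(global_step_value):
            return True  # request that training be stopped
        return None      # no request; training continues

    def end(self, session, global_step_value):
        pass


# Illustrative stop condition: stop once the global step reaches 1000.
chief_listener = StopWhenListener(lambda step: step >= 1000, is_chief=True)
worker_listener = StopWhenListener(lambda step: step >= 1000, is_chief=False)

early = chief_listener.after_save(None, 500)
late = chief_listener.after_save(None, 1000)
worker = worker_listener.after_save(None, 1000)
```

Returning None when no stop is needed mirrors the example above, where after_save() simply falls through unless the stop condition holds.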
after_save
after_save(session, global_step_value)
before_save
before_save(session, global_step_value)
begin
begin()
end
end(session, global_step_value)
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/estimator/CheckpointSaverListener