A mixin for putting Python state in an object-based checkpoint.
This is an abstract class which allows extensions to TensorFlow's object-based checkpointing (see tf.train.Checkpoint). For example, a wrapper for NumPy arrays:
import io

import numpy
import tensorflow as tf


class NumpyWrapper(tf.train.experimental.PythonState):

  def __init__(self, array):
    self.array = array

  def serialize(self):
    # Write the array to an in-memory buffer and return its contents.
    string_file = io.BytesIO()
    try:
      numpy.save(string_file, self.array, allow_pickle=False)
      serialized = string_file.getvalue()
    finally:
      string_file.close()
    return serialized

  def deserialize(self, string_value):
    # Rebuild the array from the value produced by serialize().
    string_file = io.BytesIO(string_value)
    try:
      self.array = numpy.load(string_file, allow_pickle=False)
    finally:
      string_file.close()
Instances of NumpyWrapper are checkpointable objects, and will be saved to and restored from checkpoints along with TensorFlow state such as variables.
# `prefix` is the checkpoint path prefix passed to save(); the caller must define it.
root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
save_path = root.save(prefix)
root.numpy.array *= 2.
assert [2.] == root.numpy.array
root.restore(save_path)
assert [1.] == root.numpy.array
deserialize
@abc.abstractmethod
deserialize(string_value)
Callback to deserialize the object.
serialize
@abc.abstractmethod
serialize()
Callback to serialize the object. Returns a string.
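The two callbacks form a round trip: whatever serialize returns should be enough for deserialize to rebuild the object's state. The following is a minimal sketch of that contract, not TensorFlow code. PythonStateLike is a hypothetical stand-in base class that mirrors the two abstract methods so the example runs without TensorFlow, and StepCounter and its JSON encoding are assumptions chosen for illustration. In real use the class would subclass tf.train.experimental.PythonState and be attached to a tf.train.Checkpoint.

```python
import abc
import json


class PythonStateLike(abc.ABC):
    """Hypothetical stand-in mirroring the two abstract methods documented here."""

    @abc.abstractmethod
    def serialize(self):
        """Return the object's state as a string."""

    @abc.abstractmethod
    def deserialize(self, string_value):
        """Restore the object's state from a serialized string."""


class StepCounter(PythonStateLike):
    """Hypothetical Python-side state: a step count and a list of tags."""

    def __init__(self):
        self.steps = 0
        self.tags = []

    def serialize(self):
        # Encode everything needed to rebuild the object into one string.
        return json.dumps({"steps": self.steps, "tags": self.tags})

    def deserialize(self, string_value):
        # Overwrite the current state with the state decoded from the string.
        state = json.loads(string_value)
        self.steps = state["steps"]
        self.tags = list(state["tags"])


counter = StepCounter()
counter.steps = 3
counter.tags.append("warmup")
saved = counter.serialize()   # snapshot the state

counter.steps = 10            # state changes after the snapshot
counter.tags.append("finetune")
counter.deserialize(saved)    # restore the snapshot
```

After the final call, counter holds the values it had when serialize was called. deserialize replaces state in place rather than returning a new object, which matches how the NumpyWrapper example above reassigns self.array.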
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/train/experimental/PythonState