An ExtensionType that can be batched and unbatched.
Inherits From: ExtensionType
tf.experimental.BatchableExtensionType( *args, **kwargs )
BatchableExtensionTypes can be used with APIs that require batching or unbatching, including Keras, tf.data.Dataset, and tf.map_fn. E.g.:
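```python
import tensorflow as tf

class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor

# A single Vehicle value holding a batch of three vehicles.
batch = Vehicle([120, 150, 80], [30, 40, 12])
# map_fn unbatches `batch` into scalar Vehicles and applies the lambda to each.
tf.map_fn(lambda vehicle: vehicle.top_speed * vehicle.mpg, batch,
          fn_output_signature=tf.int32).numpy()
# array([3600, 6000,  960], dtype=int32)
```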
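The paragraph above also lists tf.data.Dataset. The sketch below reuses the Vehicle type from the example and assumes that Dataset.from_tensor_slices and Dataset.batch accept a batchable extension type value the same way tf.map_fn does: the three-vehicle value is unbatched into scalar Vehicles and then regrouped into batches of two.

```python
import tensorflow as tf

class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor

batch = Vehicle([120, 150, 80], [30, 40, 12])

# Unbatch: each dataset element is a Vehicle whose fields are scalars.
ds = tf.data.Dataset.from_tensor_slices(batch)
speeds = [int(v.top_speed.numpy()) for v in ds]

# Re-batch: consecutive Vehicles are stacked field by field. Because
# drop_remainder defaults to False, the last batch holds the single leftover.
rebatched = list(ds.batch(2))
first_speeds = rebatched[0].top_speed.numpy().tolist()
last_mpg = rebatched[1].mpg.numpy().tolist()
print(speeds, first_speeds, last_mpg)
```

Unbatching followed by batching recovers the original field values, grouped two at a time.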
An ExtensionTypeBatchEncoder is used by these APIs to encode ExtensionType values. The default encoder assumes that values can be stacked, unstacked, or concatenated by simply stacking, unstacking, or concatenating every nested Tensor, ExtensionType, CompositeTensor, or TensorShape field. Extension types where this is not the case will need to override __batch_encoder__ with a custom ExtensionTypeBatchEncoder. See tf.experimental.ExtensionTypeBatchEncoder for more details.
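As a rough illustration of the default encoder handling a nested ExtensionType field, the sketch below nests one batchable extension type inside another and maps over the batch with tf.map_fn, as in the earlier example. It also attaches a hypothetical PassThroughEncoder subclass through __batch_encoder__ to show where a custom encoder is plugged in. PassThroughEncoder, Engine, and Car are illustrative names, not part of TensorFlow, and because the subclass overrides nothing it simply inherits the default field-by-field behavior. A real custom encoder would override the batching and encoding methods documented on the tf.experimental.ExtensionTypeBatchEncoder page.

```python
import tensorflow as tf

class PassThroughEncoder(tf.experimental.ExtensionTypeBatchEncoder):
  """Hypothetical encoder that overrides nothing, so it keeps the default.

  A type whose fields cannot simply be stacked, unstacked, or concatenated
  one by one would override the batching and encoding methods here.
  """

class Engine(tf.experimental.BatchableExtensionType):
  horsepower: tf.Tensor

class Car(tf.experimental.BatchableExtensionType):
  engine: Engine  # nested ExtensionType field, batched along with `mpg`
  mpg: tf.Tensor
  __batch_encoder__ = PassThroughEncoder()  # hook for a custom encoder

# A batch of two cars; the nested Engine holds a batch of two horsepowers.
cars = Car(Engine([300, 150]), [20, 35])

# map_fn unbatches `cars` into scalar Cars, including the nested Engine.
result = tf.map_fn(lambda car: car.engine.horsepower * car.mpg, cars,
                   fn_output_signature=tf.int32).numpy().tolist()
print(result)
```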
__eq__
__eq__( other )
Return self==other.
__ne__
__ne__( other )
Return self!=other.
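A minimal sketch of the two comparison methods, using the Vehicle type from the example above. This section does not state the return type, so the sketch assumes the result may be either a Python bool or a scalar boolean tf.Tensor and wraps it in bool(), which handles both cases in eager mode.

```python
import tensorflow as tf

class Vehicle(tf.experimental.BatchableExtensionType):
  top_speed: tf.Tensor
  mpg: tf.Tensor

a = Vehicle([120, 150], [30, 40])
b = Vehicle([120, 150], [30, 40])  # same field values as `a`
c = Vehicle([120, 150], [30, 41])  # differs from `a` in one mpg entry

# bool() converts a scalar eager tensor (or a plain bool) to a Python bool.
same = bool(a == b)
different = bool(a != c)
print(same, different)
```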
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/experimental/BatchableExtensionType