Updates the shape of a tensor and checks at runtime that the shape holds.
```python
tf.ensure_shape(
    x, shape, name=None
)
```
With eager execution, this is a shape assertion that returns the input:
```python
>>> import tensorflow as tf
>>> x = tf.constant([1,2,3])
>>> print(x.shape)
(3,)
>>> x = tf.ensure_shape(x, [3])
>>> x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not compatible with expected shape [5]. [Op:EnsureShape]
```
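The error above comes from a shape-compatibility test between the tensor's actual shape and the requested one. The sketch below models that test in plain Python, assuming the same rule as tf.TensorShape.is_compatible_with: known ranks must match, a None dimension matches any size, and a shape of None (unknown rank) matches anything. The names is_compatible and check_shape are hypothetical helpers for illustration, not TensorFlow APIs, and ValueError stands in for tf.errors.InvalidArgumentError.

```python
from typing import Optional, Sequence

# Hypothetical aliases: None as a dimension means "size unknown",
# None as a whole shape means "rank unknown".
Dim = Optional[int]
Shape = Optional[Sequence[Dim]]


def is_compatible(actual: Shape, expected: Shape) -> bool:
    """Return True if two (possibly partial) shapes could describe the same tensor."""
    if actual is None or expected is None:
        return True  # unknown rank is compatible with every shape
    if len(actual) != len(expected):
        return False  # known ranks must agree
    # Each dimension must be equal unless one side leaves it unknown.
    return all(a is None or e is None or a == e for a, e in zip(actual, expected))


def check_shape(actual: Sequence[int], expected: Shape) -> Sequence[int]:
    """Hypothetical stand-in for the runtime part of tf.ensure_shape."""
    if not is_compatible(actual, expected):
        # ValueError stands in for tf.errors.InvalidArgumentError here.
        raise ValueError(
            f"Shape {list(actual)} is not compatible with expected shape {list(expected)}"
        )
    return actual  # the input passes through unchanged


print(is_compatible([3], [3]))     # True
print(is_compatible([3], [5]))     # False
print(is_compatible([3], [None]))  # True: None matches any size
print(is_compatible([3], None))    # True: unknown rank matches anything
print(is_compatible([3], [3, 1]))  # False: rank mismatch
```

Under this rule, a request such as [None, 28, 3] pins down only the dimensions you care about, which is a common way to leave a batch or length dimension open.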
Inside a tf.function or v1.Graph context, it checks both the buildtime and runtime shapes. This is stricter than tf.Tensor.set_shape, which only checks the buildtime shape.
Note: This differs from tf.Tensor.set_shape in that it sets the static shape of the resulting tensor and enforces it at runtime, raising an error if the tensor's runtime shape is incompatible with the specified shape. tf.Tensor.set_shape sets the static shape of the tensor without enforcing it at runtime, which may result in inconsistencies between the statically-known shape of tensors and the runtime value of tensors.
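To make this difference concrete, here is a toy model in plain Python. It assumes that set_shape only merges the requested shape into the statically known shape (in the spirit of tf.TensorShape.merge_with), while ensure_shape performs the same merge and also compares the request against the value's actual shape. ToyTensor, compatible, merge, toy_set_shape and toy_ensure_shape are hypothetical names; no real TensorFlow tensors are involved, only known-rank shapes are modeled, and ValueError stands in for TensorFlow's errors.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

PartialShape = Tuple[Optional[int], ...]  # None = dimension unknown at build time


def compatible(a: PartialShape, b: PartialShape) -> bool:
    """Same rank, and every dimension equal unless one side is unknown."""
    return len(a) == len(b) and all(
        x is None or y is None or x == y for x, y in zip(a, b)
    )


def merge(known: PartialShape, requested: PartialShape) -> PartialShape:
    """Combine two partial shapes, keeping every dimension either side knows."""
    if not compatible(known, requested):
        raise ValueError(f"incompatible static shapes: {known} vs {requested}")
    return tuple(k if k is not None else r for k, r in zip(known, requested))


@dataclass
class ToyTensor:
    static_shape: PartialShape      # what the graph builder believes
    runtime_shape: Tuple[int, ...]  # what the value actually has


def toy_set_shape(t: ToyTensor, shape: PartialShape) -> None:
    # Buildtime only: updates the static shape in place, never looks at the value.
    t.static_shape = merge(t.static_shape, shape)


def toy_ensure_shape(t: ToyTensor, shape: PartialShape) -> ToyTensor:
    # Buildtime merge, plus a runtime check against the actual value.
    static = merge(t.static_shape, shape)
    if not compatible(t.runtime_shape, shape):
        raise ValueError(f"runtime shape {t.runtime_shape} is not compatible with {shape}")
    return ToyTensor(static, t.runtime_shape)


# A decoded image whose static shape is only partly known, like the traced
# decode_png output in this page's example, but whose value is 56x28x3.
img = ToyTensor(static_shape=(None, None, 3), runtime_shape=(56, 28, 3))
toy_set_shape(img, (28, 28, 3))
print(img.static_shape, img.runtime_shape)  # (28, 28, 3) (56, 28, 3): inconsistent, no error

img2 = ToyTensor(static_shape=(None, None, 3), runtime_shape=(56, 28, 3))
try:
    toy_ensure_shape(img2, (28, 28, 3))
except ValueError as err:
    print("ensure_shape-style check failed:", err)
```

In the model, toy_set_shape silently leaves a static shape that disagrees with the value, whereas toy_ensure_shape refuses to produce such a tensor.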
For example, when loading images of a known size:
```python
>>> @tf.function
... def decode_image(png):
...   image = tf.image.decode_png(png, channels=3)
...   # the `print` executes during tracing.
...   print("Initial shape: ", image.shape)
...   image = tf.ensure_shape(image, [28, 28, 3])
...   print("Final shape: ", image.shape)
...   return image
```
When a function is traced, no ops are executed, so shapes may be unknown. See the Concrete Functions Guide for details.
```python
>>> concrete_decode = decode_image.get_concrete_function(
...     tf.TensorSpec([], dtype=tf.string))
Initial shape:  (None, None, 3)
Final shape:  (28, 28, 3)
```

```python
>>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
>>> image = tf.cast(image, tf.uint8)
>>> png = tf.image.encode_png(image)
>>> image2 = concrete_decode(png)
>>> print(image2.shape)
(28, 28, 3)
```

```python
>>> image = tf.concat([image, image], axis=0)
>>> print(image.shape)
(56, 28, 3)
>>> png = tf.image.encode_png(image)
>>> image2 = concrete_decode(png)
Traceback (most recent call last):
tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not compatible with expected shape [28,28,3].
```
Caution: if you don't use the result of tf.ensure_shape, the check may not run.

```python
>>> @tf.function
... def bad_decode_image(png):
...   image = tf.image.decode_png(png, channels=3)
...   # the `print` executes during tracing.
...   print("Initial shape: ", image.shape)
...   # BAD: forgot to use the returned tensor.
...   tf.ensure_shape(image, [28, 28, 3])
...   print("Final shape: ", image.shape)
...   return image
```

```python
>>> image = bad_decode_image(png)
Initial shape:  (None, None, 3)
Final shape:  (None, None, 3)
>>> print(image.shape)
(56, 28, 3)
```
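In the output above, the call succeeds and returns a 56x28x3 image even though the check asked for [28, 28, 3]. One way to picture why: a traced function behaves like a dataflow graph, and when it runs, only the ops that the returned tensors depend on need to execute. The toy executor below is a simplified, hypothetical model of that pruning, assuming an op whose output nothing consumes is skipped; it is not how TensorFlow's runtime is implemented, it ignores control dependencies, and Node, run and ensure_shape_fn are illustrative names only (the check uses exact shape equality for brevity).

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Node:
    name: str
    fn: Callable[..., object]
    inputs: List["Node"] = field(default_factory=list)


def run(output: Node, executed: List[str]) -> object:
    """Evaluate `output`, running only the nodes it transitively depends on."""
    cache: Dict[str, object] = {}

    def eval_node(node: Node) -> object:
        if node.name not in cache:
            args = [eval_node(i) for i in node.inputs]
            executed.append(node.name)  # record which ops actually ran
            cache[node.name] = node.fn(*args)
        return cache[node.name]

    return eval_node(output)


def ensure_shape_fn(expected):
    """Build a toy check op; values flowing through this graph are shapes."""
    def check(shape):
        if list(shape) != list(expected):
            raise ValueError(f"{list(shape)} is not compatible with {list(expected)}")
        return shape
    return check


decode = Node("DecodePng", lambda: (56, 28, 3))
check = Node("EnsureShape", ensure_shape_fn((28, 28, 3)), [decode])

# Like bad_decode_image: the function returns the DecodePng output, so the
# EnsureShape node is never reached and no error is raised.
executed: List[str] = []
print(run(decode, executed), executed)  # (56, 28, 3) ['DecodePng']

# Like decode_image: the returned tensor is the EnsureShape output, so the check runs.
executed = []
try:
    run(check, executed)
except ValueError as err:
    print("check ran and failed:", err)
```

The fix in the model mirrors the fix in the real code: make the returned value flow through the check, as decode_image does by reassigning image = tf.ensure_shape(...).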
Args | Description |
---|---|
x | A Tensor. |
shape | A TensorShape representing the shape of this tensor, a TensorShapeProto, a list, a tuple, or None. |
name | A name for this operation (optional). Defaults to "EnsureShape". |

Returns |
---|
A Tensor. Has the same type and contents as x. At runtime, raises a tf.errors.InvalidArgumentError if shape is incompatible with the shape of x. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/ensure_shape