
Updates the shape of a tensor and checks at runtime that the shape holds.

    tf.ensure_shape(x, shape, name=None)

    >>> @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    ... def f(tensor):
    ...   return tf.ensure_shape(tensor, [3, 3])
    >>> f(tf.zeros([3, 3]))  # Passes
    <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
    array([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32)>
    >>> f([1, 2, 3])  # Fails
    Traceback (most recent call last):
    InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].

The above example raises `tf.errors.InvalidArgumentError` because the shape (3,) is not compatible with the shape (3, 3).

With eager execution, this is a shape assertion that returns the input:

    >>> x = tf.constant([1, 2, 3])
    >>> print(x.shape)
    (3,)
    >>> x = tf.ensure_shape(x, [3])
    >>> x = tf.ensure_shape(x, [5])
    Traceback (most recent call last):
    tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not compatible with expected shape [5]. [Op:EnsureShape]

Inside a `tf.function` or `v1.Graph` context, it checks both the buildtime and runtime shapes. This is stricter than `tf.Tensor.set_shape`, which only checks the buildtime shape.

Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape of the resulting tensor and enforces it at runtime, raising an error if the tensor's runtime shape is incompatible with the specified shape. `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it at runtime, which may result in inconsistencies between the statically known shape of tensors and the runtime value of tensors.
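The difference described in the note can be sketched as follows. This is a minimal illustration, not part of the original reference: the function names `with_set_shape` and `with_ensure_shape` are hypothetical, and the sketch assumes TensorFlow 2.x with eager execution enabled.

```python
import tensorflow as tf

# Accept a float32 tensor of any shape, so nothing is known statically on entry.
spec = tf.TensorSpec(shape=None, dtype=tf.float32)

@tf.function(input_signature=[spec])
def with_set_shape(t):  # hypothetical name, for illustration only
    t = tf.identity(t)
    t.set_shape([3])  # static annotation only; no runtime check is added
    return t

@tf.function(input_signature=[spec])
def with_ensure_shape(t):  # hypothetical name, for illustration only
    return tf.ensure_shape(t, [3])  # static shape plus a runtime check

# Both traced graphs report the same static output shape.
static_set = with_set_shape.get_concrete_function().outputs[0].shape
static_ensure = with_ensure_shape.get_concrete_function().outputs[0].shape

wrong = tf.zeros([5])

# The set_shape variant is expected to run despite the mismatch, because
# nothing in its graph verifies the shape when the function executes.
leaked = with_set_shape(wrong)

# The ensure_shape variant rejects the same input at runtime.
try:
    with_ensure_shape(wrong)
    raised = False
except tf.errors.InvalidArgumentError:
    raised = True
```

In this sketch, `leaked` is the kind of inconsistency the note warns about: the graph was built under the assumption of shape `[3]`, yet a tensor of a different shape passed through it unchecked.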

For example, when loading images of a known size:

    >>> @tf.function
    ... def decode_image(png):
    ...   image = tf.image.decode_png(png, channels=3)
    ...   # the `print` executes during tracing.
    ...   print("Initial shape: ", image.shape)
    ...   image = tf.ensure_shape(image, [28, 28, 3])
    ...   print("Final shape: ", image.shape)
    ...   return image

When tracing a function, no ops are executed, so shapes may be unknown. See the Concrete Functions Guide for details.

    >>> concrete_decode = decode_image.get_concrete_function(
    ...     tf.TensorSpec([], dtype=tf.string))
    Initial shape:  (None, None, 3)
    Final shape:  (28, 28, 3)

    >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
    >>> image = tf.cast(image, tf.uint8)
    >>> png = tf.image.encode_png(image)
    >>> image2 = concrete_decode(png)
    >>> print(image2.shape)
    (28, 28, 3)

    >>> image = tf.concat([image, image], axis=0)
    >>> print(image.shape)
    (56, 28, 3)
    >>> png = tf.image.encode_png(image)
    >>> image2 = concrete_decode(png)
    Traceback (most recent call last):
    tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not compatible with expected shape [28,28,3].

    >>> @tf.function
    ... def bad_decode_image(png):
    ...   image = tf.image.decode_png(png, channels=3)
    ...   # the `print` executes during tracing.
    ...   print("Initial shape: ", image.shape)
    ...   # BAD: forgot to use the returned tensor.
    ...   tf.ensure_shape(image, [28, 28, 3])
    ...   print("Final shape: ", image.shape)
    ...   return image

    >>> image = bad_decode_image(png)
    Initial shape:  (None, None, 3)
    Final shape:  (None, None, 3)
    >>> print(image.shape)
    (56, 28, 3)

Args | |
---|---|
`x` | A `Tensor`. |
`shape` | A `TensorShape` representing the shape of this tensor, a `TensorShapeProto`, a list, a tuple, or None. |
`name` | A name for this operation (optional). Defaults to "EnsureShape". |
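As an illustration of the forms that `shape` can take, the sketch below passes a list, a tuple, a `TensorShape`, a partially known shape, and `None` in eager mode. The variable names are arbitrary. The sketch assumes that a `None` dimension matches any size and that `shape=None` means an unknown rank, which is how `tf.TensorShape` treats those values.

```python
import tensorflow as tf

x = tf.zeros([2, 3])

# Fully specified shapes, given in three equivalent forms.
from_list = tf.ensure_shape(x, [2, 3])
from_tuple = tf.ensure_shape(x, (2, 3))
from_tensor_shape = tf.ensure_shape(x, tf.TensorShape([2, 3]))

# A partially known shape: only the second dimension is constrained.
partial = tf.ensure_shape(x, [None, 3])

# An unknown shape: compatible with a tensor of any rank.
unconstrained = tf.ensure_shape(x, None)
```

A partially known shape is handy when only some dimensions are fixed, for example a variable batch size in front of a fixed feature dimension.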

Returns | |
---|---|
A `Tensor`. Has the same type and contents as `x`. |

Raises | |
---|---|
`tf.errors.InvalidArgumentError` | If `shape` is incompatible with the shape of `x`. |
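One way to handle the error in eager mode is sketched below. The helper `checked_or_none` is hypothetical and not part of TensorFlow; it only shows that the exception can be caught like any other Python exception.

```python
import tensorflow as tf

def checked_or_none(x, shape):
    """Hypothetical helper: return the checked tensor, or None on a mismatch."""
    try:
        return tf.ensure_shape(x, shape)
    except tf.errors.InvalidArgumentError:
        # Raised in eager mode when the runtime shape of `x` is incompatible.
        return None

ok = checked_or_none(tf.constant([1, 2, 3]), [3])   # shapes agree
bad = checked_or_none(tf.constant([1, 2, 3]), [5])  # shapes disagree
```

Returning `None` is only one possible policy. In a data pipeline it may be preferable to let the error propagate so that malformed inputs are not silently dropped.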

© 2020 The TensorFlow Authors. All rights reserved.

Licensed under the Creative Commons Attribution License 3.0.

Code samples licensed under the Apache 2.0 License.

https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/ensure_shape