Updates the shape of a tensor and checks at runtime that the shape holds.
tf.ensure_shape(
    x, shape, name=None
)
  When executed, this operation asserts that the input tensor x's shape is compatible with the shape argument. See tf.TensorShape.is_compatible_with for details.
x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
x = tf.ensure_shape(x, [2, 3])
 Use None for unknown dimensions:
x = tf.ensure_shape(x, [None, 3])
x = tf.ensure_shape(x, [2, None])
If the tensor's shape is not compatible with the shape argument, an error is raised:
x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not compatible with expected shape [5]. [Op:EnsureShape]
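The compatibility rule and the resulting error can be checked directly in eager mode. The following is a minimal sketch; the variable names `ok`, `bad`, and `rejected` are illustrative only.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])

# Static compatibility can be checked up front on the TensorShape itself.
ok = x.shape.is_compatible_with([None, 3])  # same rank, last dimension matches
bad = x.shape.is_compatible_with([5])       # rank 1 versus rank 2

# In eager mode the mismatch surfaces when the op runs.
rejected = False
try:
    tf.ensure_shape(x, [5])
except (tf.errors.InvalidArgumentError, ValueError):
    # InvalidArgumentError is the documented error. ValueError is caught
    # only as a precaution (an assumption of this sketch).
    rejected = True

print(ok, bad, rejected)
```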
During graph construction (typically tracing a tf.function), tf.ensure_shape updates the static-shape of the result tensor by merging the two shapes. See tf.TensorShape.merge_with for details.
This is most useful when you know a shape that can't be determined statically by TensorFlow.
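The merge itself can be reproduced with tf.TensorShape alone. This sketch assumes a traced shape of [None, None] and a declared shape of [None, 3]; the names `traced`, `declared`, and `conflict` are illustrative.

```python
import tensorflow as tf

traced = tf.TensorShape([None, None])  # what tracing could infer
declared = tf.TensorShape([None, 3])   # what the caller asserts

# Each unknown dimension is filled in from the other shape.
merged = traced.merge_with(declared)
print(merged)

# Two known dimensions that disagree cannot be merged.
conflict = False
try:
    tf.TensorShape([None, 4]).merge_with(declared)
except ValueError:
    conflict = True
print(conflict)
```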
The following trivial tf.function prints the input tensor's static-shape before and after ensure_shape is applied.
@tf.function
def f(tensor):
  print("Static-shape before:", tensor.shape)
  tensor = tf.ensure_shape(tensor, [None, 3])
  print("Static-shape after:", tensor.shape)
  return tensor
 This lets you see the effect of tf.ensure_shape when the function is traced:
cf = f.get_concrete_function(tf.TensorSpec([None, None]))
Static-shape before: (None, None)
Static-shape after: (None, 3)
cf(tf.zeros([3, 3])) # Passes
cf(tf.constant([1, 2, 3])) # fails
Traceback (most recent call last):
InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].
The above example raises tf.errors.InvalidArgumentError because x's shape, (3,), is not compatible with the shape argument, (None, 3).
Inside a tf.function or v1.Graph context, it checks both the build-time and runtime shapes. This is stricter than tf.Tensor.set_shape, which only checks the build-time shape.
Note: This differs from tf.Tensor.set_shape in that it sets the static shape of the resulting tensor and enforces it at runtime, raising an error if the tensor's runtime shape is incompatible with the specified shape. tf.Tensor.set_shape sets the static shape of the tensor without enforcing it at runtime, which may result in inconsistencies between the statically-known shape of tensors and the runtime value of tensors.
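The difference can be seen in a small side-by-side sketch. The function names `with_set_shape` and `with_ensure_shape` and the input values are assumptions chosen for illustration, not part of the API.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])
def with_set_shape(t):
    # Static annotation only: nothing is verified when the function runs.
    t.set_shape([3])
    return tf.reduce_sum(t)

@tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])
def with_ensure_shape(t):
    # Static annotation plus a check that executes with the function.
    t = tf.ensure_shape(t, [3])
    return tf.reduce_sum(t)

five = tf.constant([1, 2, 3, 4, 5])  # runtime shape (5,), not (3,)

# set_shape: the wrong static shape goes unnoticed at runtime.
total = with_set_shape(five)

# ensure_shape: the same input is rejected.
caught = False
try:
    with_ensure_shape(five)
except tf.errors.InvalidArgumentError:
    caught = True

print(int(total), caught)
```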
For example, when loading images of a known size:
@tf.function
def decode_image(png):
  image = tf.image.decode_png(png, channels=3)
  # the `print` executes during tracing.
  print("Initial shape: ", image.shape)
  image = tf.ensure_shape(image,[28, 28, 3])
  print("Final shape: ", image.shape)
  return image
When a function is traced, no ops are executed, so shapes may be unknown. See the Concrete Functions Guide for details.
concrete_decode = decode_image.get_concrete_function(
    tf.TensorSpec([], dtype=tf.string))
Initial shape:  (None, None, 3)
Final shape:  (28, 28, 3)
image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
image = tf.cast(image,tf.uint8)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
print(image2.shape)
(28, 28, 3)

image = tf.concat([image,image], axis=0)
print(image.shape)
(56, 28, 3)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
Traceback (most recent call last):
tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not compatible with expected shape [28,28,3].
Caution: If you don't use the result of tf.ensure_shape, the check may not run.
@tf.function
def bad_decode_image(png):
  image = tf.image.decode_png(png, channels=3)
  # the `print` executes during tracing.
  print("Initial shape: ", image.shape)
  # BAD: forgot to use the returned tensor.
  tf.ensure_shape(image,[28, 28, 3])
  print("Final shape: ", image.shape)
  return image
image = bad_decode_image(png)
Initial shape: (None, None, 3)
Final shape: (None, None, 3)
print(image.shape)
(56, 28, 3)
| Args | |
|---|---|
| x | A Tensor. | 
| shape | A TensorShape representing the shape of this tensor, a TensorShapeProto, a list, a tuple, or None. | 
| name | A name for this operation (optional). Defaults to "EnsureShape". | 
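As a rough illustration of the accepted forms of shape, the sketch below passes a TensorShape, a TensorShapeProto, a list, and a tuple for the same tensor. Building the proto with `TensorShape.as_proto()` and the op name `"check_shape"` are choices made for this example only.

```python
import tensorflow as tf

x = tf.zeros([2, 3])

# A TensorShapeProto derived from an existing TensorShape.
proto = tf.TensorShape([2, 3]).as_proto()

# Each entry is one of the forms listed for the shape argument.
shapes = [tf.TensorShape([2, 3]), proto, [2, 3], (2, None)]

results = [tf.ensure_shape(x, s, name="check_shape") for s in shapes]
for r in results:
    print(r.shape)
```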
| Returns | |
|---|---|
| A Tensor. Has the same type and contents as x. | 
| Raises | |
|---|---|
| tf.errors.InvalidArgumentError | If shape is incompatible with the shape of x. | 
    © 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
    https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/ensure_shape