Decorator to override default implementation for binary elementwise assert APIs.
tf.experimental.dispatch_for_binary_elementwise_assert_apis(
x_type, y_type
)
The decorated function (known as the "elementwise assert handler") overrides the default implementation for any binary elementwise assert API whenever the values of the first two arguments (typically named x and y) match the specified type annotations. The handler is called with three arguments:
elementwise_assert_handler(assert_func, x, y)
Here, x and y are the first two arguments passed to the binary elementwise assert operation, and assert_func is a TensorFlow function that takes two parameters and performs the elementwise assert operation (e.g., tf.debugging.assert_equal).
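The handler contract above can be illustrated with a second, deliberately simple extension type. The `ScaledTensor` class, its `values`/`scale` fields, `scaled_assert_handler`, and the `raises_invalid_argument` helper are hypothetical names chosen for this sketch. It assumes TensorFlow 2.9 or later and eager execution, where a failed assertion raises `tf.errors.InvalidArgumentError`.

```python
import tensorflow as tf


class ScaledTensor(tf.experimental.ExtensionType):
  """Hypothetical type whose logical value is `values * scale`."""
  values: tf.Tensor
  scale: tf.Tensor


@tf.experimental.dispatch_for_binary_elementwise_assert_apis(ScaledTensor, ScaledTensor)
def scaled_assert_handler(assert_func, x, y):
  # `assert_func` performs the requested assert (e.g. assert_equal); it is
  # called with ordinary tensors, so the default implementation does the check.
  return assert_func(x.values * x.scale, y.values * y.scale)


def raises_invalid_argument(fn):
  """Returns True if calling `fn()` raises tf.errors.InvalidArgumentError."""
  try:
    fn()
  except tf.errors.InvalidArgumentError:
    return True
  return False


a = ScaledTensor([1.0, 2.0], 2.0)  # logical value [2.0, 4.0]
b = ScaledTensor([2.0, 4.0], 1.0)  # logical value [2.0, 4.0]

equal_failed = raises_invalid_argument(lambda: tf.debugging.assert_equal(a, b))
less_failed = raises_invalid_argument(lambda: tf.debugging.assert_less(a, b))
```

Because both operands have the same logical value, the equality check passes while `assert_less` fails; the handler never inspects the condition itself, it only decides which tensors the default assert sees.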
The following example shows how this decorator can be used to update all binary elementwise assert operations to handle a MaskedTensor type:
```python
import tensorflow as tf

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

@tf.experimental.dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor)
def binary_elementwise_assert_api_handler(assert_func, x, y):
  # Only compare positions that are unmasked in both operands.
  merged_mask = tf.logical_and(x.mask, y.mask)
  selected_x_values = tf.boolean_mask(x.values, merged_mask)
  selected_y_values = tf.boolean_mask(y.values, merged_mask)
  # Return whatever the assert API returns (None in eager mode, an op in graph mode).
  return assert_func(selected_x_values, selected_y_values)

a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
tf.debugging.assert_equal(a, b)  # assert passed; no exception was thrown

a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
tf.debugging.assert_greater(a, b)
# Traceback (most recent call last):
# InvalidArgumentError: Condition x > y did not hold.
```
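Because one registration covers every binary elementwise assert API, the masking logic from the MaskedTensor example should also apply to other asserts. This sketch assumes that `tf.debugging.assert_less_equal` and `tf.debugging.assert_none_equal` are among the dispatched APIs, and re-creates the masked type under the hypothetical name `MaskedValues` (with a hypothetical `masked_assert_handler`) so that it runs on its own.

```python
import tensorflow as tf


class MaskedValues(tf.experimental.ExtensionType):
  """Hypothetical stand-in for the MaskedTensor type in the example above."""
  values: tf.Tensor
  mask: tf.Tensor


@tf.experimental.dispatch_for_binary_elementwise_assert_apis(MaskedValues, MaskedValues)
def masked_assert_handler(assert_func, x, y):
  # Keep only positions that are unmasked (True) in both operands.
  merged_mask = tf.logical_and(x.mask, y.mask)
  return assert_func(tf.boolean_mask(x.values, merged_mask),
                     tf.boolean_mask(y.values, merged_mask))


def raises_invalid_argument(fn):
  """Returns True if calling `fn()` raises tf.errors.InvalidArgumentError."""
  try:
    fn()
  except tf.errors.InvalidArgumentError:
    return True
  return False


# Position 1 violates x <= y, but it is masked out in `x`, so it is ignored.
x = MaskedValues([1, 9, 3], [True, False, True])
y = MaskedValues([1, 2, 4], [True, True, True])

le_failed = raises_invalid_argument(lambda: tf.debugging.assert_less_equal(x, y))
# Position 0 holds equal, unmasked values, so "none equal" must fail.
ne_failed = raises_invalid_argument(lambda: tf.debugging.assert_none_equal(x, y))
```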
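| Args | |
|---|---|
| x_type | A type annotation for the first argument (typically named x). The handler is called when x matches this annotation and y matches y_type. |
| y_type | A type annotation for the second argument (typically named y). The handler is called when y matches this annotation and x matches x_type. |

| Returns |
|---|
| A decorator. |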
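Since the return value is an ordinary Python decorator, it can also be applied by calling it directly on an existing function, and a single registration then serves several assert APIs. In this sketch the `Box` type, the `box_handler` function, and the `calls` list are hypothetical names used only to count how often the handler runs.

```python
import tensorflow as tf


class Box(tf.experimental.ExtensionType):
  """Hypothetical single-field wrapper around a tensor."""
  value: tf.Tensor


calls = []  # one entry per handler invocation


def box_handler(assert_func, x, y):
  calls.append((x, y))
  # Unwrap both operands and defer to the default assert implementation.
  return assert_func(x.value, y.value)


# Equivalent to writing the decorator with @ syntax above `box_handler`.
decorator = tf.experimental.dispatch_for_binary_elementwise_assert_apis(Box, Box)
decorator(box_handler)

tf.debugging.assert_equal(Box([1, 2]), Box([1, 2]))
tf.debugging.assert_greater_equal(Box([3, 4]), Box([1, 4]))
```

Both assertions hold, and each call is routed through `box_handler` once.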
The binary elementwise assert APIs are:
<
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/experimental/dispatch_for_binary_elementwise_assert_apis