Returns the indices of non-zero elements, or multiplexes x and y.
tf.where( condition, x=None, y=None, name=None )
This operation has two modes:

- Return the indices of non-zero elements - When only condition is provided, the result is an int64 tensor where each row is the index of a non-zero element of condition. The result's shape is [tf.math.count_nonzero(condition), tf.rank(condition)].
- Multiplex x and y - When both x and y are provided, the result has the shape of x, y, and condition broadcast together. The result is taken from x where condition is True, or from y where condition is False.

Return the indices of non-zero elements

Note: In this mode condition can have a dtype of bool or any numeric dtype.
If x and y are not provided (both are None), tf.where returns the indices of the non-zero elements of condition as a 2-D tensor with shape [n, d], where n is the number of non-zero elements in condition (tf.math.count_nonzero(condition)) and d is the number of axes of condition (tf.rank(condition)).

Indices are output in row-major order. The condition can have a dtype of tf.bool or any numeric dtype.
Here condition is a 1-axis bool tensor with 2 True values. The result has a shape of [2, 1].

    >>> tf.where([True, False, False, True]).numpy()
    array([[0], [3]])
Here condition is a 2-axis integer tensor with 3 non-zero values. The result has a shape of [3, 2].

    >>> tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
    array([[0, 0], [1, 0], [1, 2]])
Here condition is a 3-axis float tensor with 5 non-zero values. The output shape is [5, 3].

    >>> float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
    ...                 [[0, 0], [0, 0], [99, 0]]]
    >>> tf.where(float_tensor).numpy()
    array([[0, 0, 0], [0, 1, 1], [0, 2, 0], [0, 2, 1], [1, 2, 0]])
These are the same indices that tf.sparse.SparseTensor would use to represent the condition tensor:

    >>> sparse = tf.sparse.from_dense(float_tensor)
    >>> sparse.indices.numpy()
    array([[0, 0, 0], [0, 1, 1], [0, 2, 0], [0, 2, 1], [1, 2, 0]])
A complex number is considered non-zero if either the real or imaginary component is non-zero:

    >>> tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
    array([[1], [2], [3]])
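The index mode can be modelled in a few lines of plain Python. This is only a sketch of the semantics described above (row-major order, one index row per non-zero element), not TensorFlow's implementation. The helper names shape_of and nonzero_indices are hypothetical, and the input is assumed to be a non-empty, rectangular nested list.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list (a scalar has shape ())."""
    shape = []
    while isinstance(nested, (list, tuple)):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def nonzero_indices(condition):
    """Hypothetical model of single-argument tf.where on nested lists."""
    shape = shape_of(condition)
    indices = []
    # itertools.product varies the last axis fastest, which is row-major order.
    for index in product(*(range(n) for n in shape)):
        element = condition
        for i in index:
            element = element[i]
        # True, non-zero ints and floats, and complex numbers with any
        # non-zero component all compare unequal to 0.
        if element != 0:
            indices.append(list(index))
    return indices


print(nonzero_indices([[1, 0, 0], [1, 0, 1]]))  # [[0, 0], [1, 0], [1, 2]]
```

The number of rows equals the number of non-zero elements and the number of columns equals the number of axes, matching the [n, d] shape given above.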
Multiplex x and y

Note: In this mode condition must have a dtype of bool.
If x and y are also provided (both have non-None values), the condition tensor acts as a mask that chooses whether the corresponding element / row in the output should be taken from x (if the element in condition is True) or y (if it is False).

The shape of the result is formed by broadcasting together the shapes of condition, x, and y.
When all three inputs have the same size, each is handled element-wise.

    >>> tf.where([True, False, False, True], [1, 2, 3, 4], [100, 200, 300, 400]).numpy()
    array([ 1, 200, 300, 4], dtype=int32)
There are two main rules for broadcasting:

A length-1 vector is stretched to match the other vectors:

    >>> tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
    array([ 1, 100, 100, 4], dtype=int32)

A scalar is expanded to match the other arguments:

    >>> tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
    array([[ 1, 100], [100, 4]], dtype=int32)
    >>> tf.where([[True, False], [False, True]], 1, 100).numpy()
    array([[ 1, 100], [100, 1]], dtype=int32)
A scalar condition returns the complete x or y tensor, with broadcasting applied.

    >>> tf.where(True, [1, 2, 3, 4], 100).numpy()
    array([1, 2, 3, 4], dtype=int32)
    >>> tf.where(False, [1, 2, 3, 4], 100).numpy()
    array([100, 100, 100, 100], dtype=int32)
For a non-trivial example of broadcasting, here condition has a shape of [3], x has a shape of [3,3], and y has a shape of [3,1]. Broadcasting first expands the shape of condition to [1,3]. The final broadcast shape is [3,3]. condition will select columns from x and y. Since y only has one column, all columns from y will be identical.

    >>> tf.where([True, False, True],
    ...          x=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...          y=[[100], [200], [300]]
    ... ).numpy()
    array([[ 1, 100, 3], [ 4, 200, 6], [ 7, 300, 9]], dtype=int32)
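The broadcasting rules above can be traced with a small plain-Python model. It is a sketch under stated assumptions, not TensorFlow's implementation: inputs are non-empty rectangular nested lists or scalars, shapes are right-aligned, and an axis of size 1 is stretched. The helper names (shape_of, broadcast_shape, read, build, select) are hypothetical.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list (a scalar has shape ())."""
    shape = []
    while isinstance(nested, (list, tuple)):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def broadcast_shape(*shapes):
    """Right-align the shapes; along each axis the sizes must match or be 1."""
    rank = max(len(s) for s in shapes)
    padded = [(1,) * (rank - len(s)) + s for s in shapes]
    result = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def read(nested, shape, index):
    """Read the element of `nested` that a broadcast output index maps to."""
    # Ignore leading axes this input lacks; a size-1 axis always reads slot 0.
    own = index[len(index) - len(shape):]
    for size, i in zip(shape, own):
        nested = nested[0 if size == 1 else i]
    return nested


def build(shape, flat):
    """Reshape a flat row-major list into nested lists of the given shape."""
    if not shape:
        return flat[0]
    step = len(flat) // shape[0]
    return [build(shape[1:], flat[k * step:(k + 1) * step]) for k in range(shape[0])]


def select(condition, x, y):
    """Hypothetical model of tf.where(condition, x, y) on nested lists."""
    args = (condition, x, y)
    shapes = [shape_of(a) for a in args]
    out_shape = broadcast_shape(*shapes)
    flat = []
    for index in product(*(range(n) for n in out_shape)):
        c, a, b = (read(arg, s, index) for arg, s in zip(args, shapes))
        flat.append(a if c else b)
    return build(out_shape, flat)


print(select([True, False, True],
             [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
             [[100], [200], [300]]))
# [[1, 100, 3], [4, 200, 6], [7, 300, 9]]
```

Because the shape [3] of condition is right-aligned to [1,3], its entries line up with the last axis, which is why the selection happens per column.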
Note that if the gradient of either branch of the tf.where generates a NaN, then the gradient of the entire tf.where will be NaN. This is because the gradient calculation for tf.where combines the two branches, for performance reasons.

A workaround is to use an inner tf.where to ensure the function has no asymptote, and to avoid computing a value whose gradient is NaN by replacing dangerous inputs with safe inputs.
Instead of this,

    >>> x = tf.constant(0., dtype=tf.float32)
    >>> with tf.GradientTape() as tape:
    ...   tape.watch(x)
    ...   y = tf.where(x < 1., 0., 1. / x)
    >>> print(tape.gradient(y, x))
    tf.Tensor(nan, shape=(), dtype=float32)

Although the 1. / x value is never used, its gradient is NaN when x = 0. Instead, guard it with another tf.where:

    >>> x = tf.constant(0., dtype=tf.float32)
    >>> with tf.GradientTape() as tape:
    ...   tape.watch(x)
    ...   safe_x = tf.where(tf.equal(x, 0.), 1., x)
    ...   y = tf.where(x < 1., 0., 1. / safe_x)
    >>> print(tape.gradient(y, x))
    tf.Tensor(0.0, shape=(), dtype=float32)
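One way to see where the NaN comes from is to redo the chain rule with plain floats. The sketch below is a model, not TensorFlow code. It assumes that the unselected branch receives an upstream gradient of zero, which is then multiplied by that branch's own local derivative. The function name is hypothetical.

```python
import math


def gradient_via_unselected_branch(denominator):
    """Gradient reaching x through the 1 / denominator branch when it is NOT selected."""
    # Assumption: the select hands a zero upstream gradient to the branch it did not pick.
    upstream = 0.0
    # d(1/d)/dd = -1 / d**2, which is unbounded at d = 0 (modelled here as -inf).
    local = -math.inf if denominator == 0 else -1.0 / denominator ** 2
    return upstream * local


unguarded = gradient_via_unselected_branch(0.0)  # 1. / x at x = 0
guarded = gradient_via_unselected_branch(1.0)    # 1. / safe_x, with safe_x = 1
print(math.isnan(unguarded), guarded == 0.0)     # True True
```

Under IEEE 754 arithmetic, zero times infinity is NaN, so masking the output alone is not enough. Replacing the dangerous input keeps the local derivative finite, and the zero upstream gradient then yields zero.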
See also:

- tf.sparse - The indices returned by the first form of tf.where can be useful in tf.sparse.SparseTensor objects.
- tf.gather_nd, tf.scatter_nd, and related ops - Given the list of indices returned from tf.where, the scatter and gather family of ops can be used to fetch values or insert values at those indices.
- tf.strings.length - tf.string is not an allowed dtype for the condition. Use the string length instead.

Args | |
---|---|
condition | A tf.Tensor of dtype bool, or any numeric dtype. condition must have dtype bool when x and y are provided. |
x | If provided, a Tensor which is of the same type as y, and has a shape broadcastable with condition and y. |
y | If provided, a Tensor which is of the same type as x, and has a shape broadcastable with condition and x. |
name | A name of the operation (optional). |
Returns | |
---|---|
If x and y are provided: A Tensor with the same type as x and y, and shape that is broadcast from condition, x, and y. Otherwise, a Tensor with shape [tf.math.count_nonzero(condition), tf.rank(condition)]. |
Raises | |
---|---|
ValueError | When exactly one of x or y is non-None, or the shapes are not all broadcastable. |
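
The related ops mentioned earlier (tf.gather_nd, tf.scatter_nd, and tf.strings.length) can be combined with the index form of tf.where roughly as follows. This is a minimal sketch that assumes TensorFlow 2.x with eager execution. The variable names are illustrative only.

```python
import tensorflow as tf

values = tf.constant([[0.0, 2.5, 0.0], [4.0, 0.0, 6.0]])

# Single-argument form: one int64 index row per non-zero element.
indices = tf.where(values)

# Fetch the values stored at those indices.
nonzero = tf.gather_nd(values, indices)

# Write doubled values into a zero tensor of the same shape.
# scatter_nd expects the shape to use the same integer dtype as the indices.
shape = tf.shape(values, out_type=tf.int64)
doubled = tf.scatter_nd(indices, nonzero * 2.0, shape)

# tf.string is not a valid condition dtype, so test the string length instead.
words = tf.constant(["a", "", "bc", ""])
non_empty = tf.where(tf.strings.length(words))

print(indices.numpy().tolist())    # [[0, 1], [1, 0], [1, 2]]
print(nonzero.numpy().tolist())    # [2.5, 4.0, 6.0]
print(doubled.numpy().tolist())    # [[0.0, 5.0, 0.0], [8.0, 0.0, 12.0]]
print(non_empty.numpy().tolist())  # [[0], [2]]
```

The same indices, values, and shape could also be passed to tf.sparse.SparseTensor when a sparse representation is preferred over a dense one.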
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/where