
Return the elements where `condition` is `True` (multiplexing `x` and `y`).

    tf.where(condition, x=None, y=None, name=None)

This operator has two modes: in one mode both `x` and `y` are provided; in the other, neither is provided. `condition` is always expected to be a `tf.Tensor` of type `bool`.

Indices of `True` elements

If `x` and `y` are not provided (both are None):

`tf.where` will return the indices of the elements of `condition` that are `True`, in the form of a 2-D tensor with shape `(n, d)`, where `n` is the number of `True` elements in `condition` and `d` is the number of dimensions in `condition`.

Indices are output in row-major order.

    >>> tf.where([True, False, False, True])
    <tf.Tensor: shape=(2, 1), dtype=int64, numpy=
    array([[0],
           [3]])>

    >>> tf.where([[True, False], [False, True]])
    <tf.Tensor: shape=(2, 2), dtype=int64, numpy=
    array([[0, 0],
           [1, 1]])>

    >>> tf.where([[[True, False], [False, True], [True, True]]])
    <tf.Tensor: shape=(4, 3), dtype=int64, numpy=
    array([[0, 0, 0],
           [0, 1, 1],
           [0, 2, 0],
           [0, 2, 1]])>
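The index mode can be modelled without TensorFlow. The following is a minimal pure-Python sketch, not the library's implementation; the helper names `shape_of` and `true_indices` are hypothetical, and the input is assumed to be a rectangular (non-ragged) nested list of booleans. It shows why each output row has one entry per dimension and why rows appear in row-major order.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list (assumed non-ragged)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return tuple(shape)


def true_indices(condition):
    """Hypothetical model of the index mode: one row of coordinates per True element."""
    rows = []
    # product() varies the last coordinate fastest, which is row-major order.
    for index in product(*(range(n) for n in shape_of(condition))):
        element = condition
        for i in index:
            element = element[i]
        if element:
            rows.append(list(index))
    return rows


print(true_indices([[True, False], [False, True]]))
```

The result has `n` rows (one per `True` element) and `d` columns (one per dimension of the input), matching the `(n, d)` shape described above.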

Multiplexing between `x` and `y`

If `x` and `y` are provided (both have non-None values):

`tf.where` will choose an output shape from the shapes of `condition`, `x`, and `y` that all three shapes are broadcastable to.
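As an illustration of how such an output shape can be derived, here is a small pure-Python sketch. It assumes the usual NumPy-style broadcasting rule (shapes are aligned on their trailing dimensions, and a dimension of size 1 stretches to match); the function name `broadcast_shape` is hypothetical and is not part of the TensorFlow API.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the common shape all inputs broadcast to, or raise ValueError."""
    result = []
    # Walk the shapes from the last dimension backwards; missing leading
    # dimensions are treated as size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# condition of shape (4,), x of shape (4,), y a scalar of shape ()
print(broadcast_shape((4,), (4,), ()))
```

Under this rule a scalar or a length-1 operand is compatible with any other shape, while two different sizes greater than 1 in the same position are not.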

The `condition` tensor acts as a mask that chooses whether the corresponding element / row in the output should be taken from `x` (if the element in `condition` is `True`) or `y` (if it is `False`).

    >>> tf.where([True, False, False, True], [1,2,3,4], [100,200,300,400])
    <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 200, 300,   4], dtype=int32)>
    >>> tf.where([True, False, False, True], [1,2,3,4], [100])
    <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4], dtype=int32)>
    >>> tf.where([True, False, False, True], [1,2,3,4], 100)
    <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4], dtype=int32)>
    >>> tf.where([True, False, False, True], 1, 100)
    <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   1], dtype=int32)>

    >>> tf.where(True, [1,2,3,4], 100)
    <tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4], dtype=int32)>
    >>> tf.where(False, [1,2,3,4], 100)
    <tf.Tensor: shape=(4,), dtype=int32, numpy=array([100, 100, 100, 100], dtype=int32)>
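The element-wise selection in these examples can be mimicked in plain Python. This is only a sketch for scalars and flat lists, with a hypothetical helper name `select_1d`; it always returns a list (even when every input is a scalar) and does not attempt to model dtypes or higher-rank tensors.

```python
def select_1d(condition, x, y):
    """Model of the multiplexing mode for scalars and flat lists only."""

    def as_list(value):
        # Treat a scalar as a length-1 list so it can be broadcast.
        return value if isinstance(value, list) else [value]

    cond, xs, ys = as_list(condition), as_list(x), as_list(y)
    lengths = {len(v) for v in (cond, xs, ys) if len(v) != 1}
    if len(lengths) > 1:
        raise ValueError("inputs are not broadcastable to a common length")
    n = lengths.pop() if lengths else 1

    def at(values, i):
        # A length-1 operand is repeated for every output position.
        return values[0] if len(values) == 1 else values[i]

    return [at(xs, i) if at(cond, i) else at(ys, i) for i in range(n)]


print(select_1d([True, False, False, True], [1, 2, 3, 4], 100))
```

Each output position takes its value from `x` where the (possibly repeated) condition is true and from `y` otherwise, which is the masking behaviour described above.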

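Note that if the gradient of either branch of the `tf.where` generates a NaN, then the gradient of the entire `tf.where` will be NaN. This happens because gradients are computed for both branches: the unselected branch receives a zero upstream gradient, but zero multiplied by a NaN derivative is still NaN. A workaround is to use an inner `tf.where` to ensure the function has no asymptote, and to avoid computing a value whose gradient is NaN by replacing dangerous inputs with safe inputs.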

Instead of this:

    >>> y = tf.constant(-1, dtype=tf.float32)
    >>> tf.where(y > 0, tf.sqrt(y), y)
    <tf.Tensor: shape=(), dtype=float32, numpy=-1.0>

use this:

    >>> tf.where(y > 0, tf.sqrt(tf.where(y > 0, y, 1)), y)
    <tf.Tensor: shape=(), dtype=float32, numpy=-1.0>

Both expressions return the same forward value; only their gradients differ.
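To see why the inner `tf.where` helps, the chain rule for these two expressions can be worked through by hand in plain Python. This is a simplified model under stated assumptions, not TensorFlow's autodiff: the helpers `sqrt_grad` and `where_grad` are hypothetical, and the derivative of the square root is taken to be NaN for negative inputs.

```python
import math


def sqrt_grad(v):
    """Derivative of sqrt at v; NaN outside the domain, inf at zero."""
    if v < 0:
        return math.nan
    return 0.5 / math.sqrt(v) if v > 0 else math.inf


def where_grad(y, safe):
    """Hand-derived d/dy of where(y > 0, sqrt(inner), y).

    With safe=False, inner is y itself; with safe=True, inner is
    where(y > 0, y, 1), as in the workaround.
    """
    take_sqrt = y > 0
    inner = (y if take_sqrt else 1.0) if safe else y
    # The outer where sends upstream gradient 1 to the selected branch and 0
    # to the other; sqrt then multiplies it by its local derivative.
    grad_via_sqrt = (1.0 if take_sqrt else 0.0) * sqrt_grad(inner)
    if safe:
        # The inner where forwards gradient to y only where y > 0.
        grad_via_sqrt *= 1.0 if take_sqrt else 0.0
    grad_via_y = 0.0 if take_sqrt else 1.0
    return grad_via_sqrt + grad_via_y


print(where_grad(-1.0, safe=False))  # 0.0 * nan poisons the sum
print(where_grad(-1.0, safe=True))
```

In the unsafe form the masked branch contributes `0.0 * nan`, which is NaN, so the total is NaN even though that branch was not selected. In the safe form the square root only ever sees a positive input, so its derivative is finite and the zero upstream gradient cancels it cleanly.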

Args | |
---|---|
`condition` | A `tf.Tensor` of type `bool`. |
`x` | If provided, a Tensor which is of the same type as `y`, and has a shape broadcastable with `condition` and `y`. |
`y` | If provided, a Tensor which is of the same type as `x`, and has a shape broadcastable with `condition` and `x`. |
`name` | A name of the operation (optional). |

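Returns | |
---|---|
If `x` and `y` are provided: A `Tensor` with the same type as `x` and `y`, and shape that is broadcast from `condition`, `x`, and `y`. Otherwise, a `Tensor` with shape `(num_true, dim_size(condition))`, where `num_true` is the number of `True` elements and `dim_size(condition)` is the number of dimensions in `condition`. |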

Raises | |
---|---|
`ValueError` | When exactly one of `x` or `y` is non-None, or the shapes are not all broadcastable. |

© 2020 The TensorFlow Authors. All rights reserved.

Licensed under the Creative Commons Attribution License 3.0.

Code samples licensed under the Apache 2.0 License.

https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/where