Gather slices from `params` into a Tensor with shape specified by `indices`.
tf.compat.v1.gather_nd(params, indices, name=None, batch_dims=0)
`indices` is a Tensor of indices into `params`. The index vectors are arranged along the last axis of `indices`.
This is similar to `tf.gather`, in which `indices` defines slices into the first dimension of `params`. In `tf.gather_nd`, `indices` defines slices into the first `N` dimensions of `params`, where `N = indices.shape[-1]`.
In the simplest case the vectors in `indices` index the full rank of `params`:

    >>> tf.gather_nd(
    ...     indices=[[0, 0], [1, 1]],
    ...     params=[['a', 'b'], ['c', 'd']]).numpy()
    array([b'a', b'd'], dtype=object)

In this case the result has one fewer axis than `indices`, and each index vector is replaced by the scalar it indexes in `params`.
The shape relationship is:

    index_depth = indices.shape[-1]
    assert index_depth == params.shape.rank
    result_shape = indices.shape[:-1]

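To make "each index vector is replaced by the element it indexes" concrete, here is a minimal pure-Python sketch on nested lists. The helper name `lookup` is hypothetical and is not part of TensorFlow; it only models the lookup semantics described above, not how the op is implemented.

```python
from functools import reduce
from operator import getitem


def lookup(params, index_vector):
    """Follow one index vector into nested lists, one axis per entry.

    Hypothetical helper for illustration only; not a TensorFlow API.
    """
    return reduce(getitem, index_vector, params)


params = [['a', 'b'], ['c', 'd']]
indices = [[0, 0], [1, 1]]

# Each index vector is replaced by the element it selects from params.
result = [lookup(params, vec) for vec in indices]
print(result)  # ['a', 'd']
```

The same helper returns a whole row when the index vector is shorter than the rank of `params`, for example `lookup(params, [1])` gives `['c', 'd']`.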
If `indices` has a rank of `K`, it is helpful to think of `indices` as a (K-1)-dimensional tensor of indices into `params`.
If the index vectors do not index the full rank of `params`, then each location in the result contains a slice of `params`. This example collects rows from a matrix:

    >>> tf.gather_nd(
    ...     indices=[[1], [0]],
    ...     params=[['a', 'b', 'c'], ['d', 'e', 'f']]).numpy()
    array([[b'd', b'e', b'f'],
           [b'a', b'b', b'c']], dtype=object)

Here `indices` contains 2 index vectors (an outer shape of `[2]`), each with a length of `1`. Each index vector refers to a row of the `params` matrix. Each row has a shape of `[3]`, so the output shape is `[2, 3]`.
In this case, the relationship between the shapes is:

    index_depth = indices.shape[-1]
    outer_shape = indices.shape[:-1]
    assert index_depth <= params.shape.rank
    inner_shape = params.shape[index_depth:]
    output_shape = outer_shape + inner_shape

It is helpful to think of the results in this case as tensors-of-tensors. The shape of the outer tensor is set by the leading dimensions of `indices`, while the shape of the inner tensors is the shape of a single slice.
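The outer/inner shape rule can be checked without building any tensors. The following is a small sketch that works on plain shape lists; the function name `gather_nd_output_shape` is hypothetical, and the sketch assumes `batch_dims=0`.

```python
def gather_nd_output_shape(params_shape, indices_shape):
    """Return the result shape for batch_dims=0, given shapes as lists.

    Hypothetical helper for illustration only; not a TensorFlow API.
    """
    index_depth = indices_shape[-1]
    if index_depth > len(params_shape):
        raise ValueError("index vectors are longer than the rank of params")
    outer_shape = list(indices_shape[:-1])       # set by indices
    inner_shape = list(params_shape[index_depth:])  # shape of one slice
    return outer_shape + inner_shape


# Rows from a [2, 3] matrix, selected by two length-1 index vectors.
print(gather_nd_output_shape([2, 3], [2, 1]))  # [2, 3]
# Full-rank indexing: the inner shape is empty, so only the outer shape remains.
print(gather_nd_output_shape([2, 2], [2, 2]))  # [2]
```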
Additionally, both `params` and `indices` can have `M` leading batch dimensions that exactly match. In this case `batch_dims` must be set to `M`.
For example, to collect one row from each of a batch of matrices you could set the leading elements of the index vectors to be their location in the batch:

    >>> tf.gather_nd(
    ...     indices=[[0, 1], [1, 0], [2, 4], [3, 2], [4, 1]],
    ...     params=tf.zeros([5, 7, 3])).shape.as_list()
    [5, 3]

The `batch_dims` argument lets you omit those leading location dimensions from the index:

    >>> tf.gather_nd(
    ...     batch_dims=1,
    ...     indices=[[1], [0], [4], [2], [1]],
    ...     params=tf.zeros([5, 7, 3])).shape.as_list()
    [5, 3]

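The relationship between the two index forms above can be sketched in plain Python: prepending each index vector's position in the batch turns the `batch_dims=1` form into the explicit form. The helper name `add_batch_location` is hypothetical, and the sketch assumes `indices` has shape `[batch, depth]`.

```python
def add_batch_location(indices):
    """Prepend each index vector's batch position.

    Hypothetical helper for illustration only; assumes indices has
    shape [batch, depth], given as a nested list.
    """
    return [[b] + list(vec) for b, vec in enumerate(indices)]


batched = [[1], [0], [4], [2], [1]]  # the form used with batch_dims=1
explicit = add_batch_location(batched)
print(explicit)  # [[0, 1], [1, 0], [2, 4], [3, 2], [4, 1]]
```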
This is equivalent to calling a separate `gather_nd` for each location in the batch dimensions.

    >>> params = tf.zeros([5, 7, 3])
    >>> indices = tf.zeros([5, 1])
    >>> batch_dims = 1
    >>>
    >>> index_depth = indices.shape[-1]
    >>> batch_shape = indices.shape[:batch_dims]
    >>> assert params.shape[:batch_dims] == batch_shape
    >>> outer_shape = indices.shape[batch_dims:-1]
    >>> assert index_depth <= params.shape.rank
    >>> inner_shape = params.shape[batch_dims + index_depth:]
    >>> output_shape = batch_shape + outer_shape + inner_shape
    >>> output_shape.as_list()
    [5, 3]

Indexing into a 3-tensor:

    >>> tf.gather_nd(
    ...     indices=[[1]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([[[b'a1', b'b1'],
            [b'c1', b'd1']]], dtype=object)
    >>> tf.gather_nd(
    ...     indices=[[0, 1], [1, 0]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([[b'c0', b'd0'],
           [b'a1', b'b1']], dtype=object)
    >>> tf.gather_nd(
    ...     indices=[[0, 0, 1], [1, 0, 1]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([b'b0', b'b1'], dtype=object)

The examples below are for the case when only indices have leading extra dimensions. If both 'params' and 'indices' have leading batch dimensions, use the 'batch_dims' parameter to run gather_nd in batch mode.
Batched indexing into a matrix:

    >>> tf.gather_nd(
    ...     indices=[[[0, 0]], [[0, 1]]],
    ...     params=[['a', 'b'], ['c', 'd']]).numpy()
    array([[b'a'],
           [b'b']], dtype=object)

Batched slice indexing into a matrix:

    >>> tf.gather_nd(
    ...     indices=[[[1]], [[0]]],
    ...     params=[['a', 'b'], ['c', 'd']]).numpy()
    array([[[b'c', b'd']],
           [[b'a', b'b']]], dtype=object)

Batched indexing into a 3-tensor:

    >>> tf.gather_nd(
    ...     indices=[[[1]], [[0]]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([[[[b'a1', b'b1'],
             [b'c1', b'd1']]],
           [[[b'a0', b'b0'],
             [b'c0', b'd0']]]], dtype=object)
    >>> tf.gather_nd(
    ...     indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([[[b'c0', b'd0'],
            [b'a1', b'b1']],
           [[b'a0', b'b0'],
            [b'c1', b'd1']]], dtype=object)
    >>> tf.gather_nd(
    ...     indices=[[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([[b'b0', b'b1'],
           [b'd0', b'c1']], dtype=object)

Examples with batched 'params' and 'indices':

    >>> tf.gather_nd(
    ...     batch_dims=1,
    ...     indices=[[1], [0]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([[b'c0', b'd0'],
           [b'a1', b'b1']], dtype=object)
    >>> tf.gather_nd(
    ...     batch_dims=1,
    ...     indices=[[[1]], [[0]]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([[[b'c0', b'd0']],
           [[b'a1', b'b1']]], dtype=object)
    >>> tf.gather_nd(
    ...     batch_dims=1,
    ...     indices=[[[1, 0]], [[0, 1]]],
    ...     params=[[['a0', 'b0'], ['c0', 'd0']],
    ...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
    array([[b'c0'],
           [b'b1']], dtype=object)

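The batched examples above can be reproduced with a pure-Python model that runs one unbatched gather per batch element, which is the equivalence the text describes for `batch_dims`. Both function names below are hypothetical, the model works on nested lists rather than tensors, and it only covers `batch_dims=1`; it is a sketch of the semantics, not TensorFlow's implementation.

```python
def gather_nd_ref(params, indices):
    """Model of gather_nd with batch_dims=0 on nested lists (illustration only)."""
    # A flat list of ints is a single index vector: walk into params.
    if all(isinstance(i, int) for i in indices):
        for i in indices:
            params = params[i]
        return params
    # Otherwise keep the outer structure of indices and recurse.
    return [gather_nd_ref(params, sub) for sub in indices]


def gather_nd_batch1_ref(params, indices):
    """Model of batch_dims=1: one independent gather per batch element."""
    return [gather_nd_ref(p, idx) for p, idx in zip(params, indices)]


params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]

print(gather_nd_batch1_ref(params, [[1], [0]]))
# [['c0', 'd0'], ['a1', 'b1']]
print(gather_nd_batch1_ref(params, [[[1, 0]], [[0, 1]]]))
# [['c0'], ['b1']]
```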
See also `tf.gather`.
Args | |
---|---|
`params` | A `Tensor`. The tensor from which to gather values. |
`indices` | A `Tensor`. Must be one of the following types: `int32`, `int64`. Index tensor. |
`name` | A name for the operation (optional). |
`batch_dims` | An integer or a scalar `Tensor`. The number of batch dimensions. |

Returns | |
---|---|
A `Tensor`. Has the same type as `params`. | |
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/compat/v1/gather_nd