"Scatter updates
into an existing tensor according to indices
.
```python
tf.tensor_scatter_nd_update(
    tensor, indices, updates, name=None
)
```
This operation creates a new tensor by applying sparse `updates` to the input `tensor`. This is similar to an index assignment.
```python
# Not implemented: tensors cannot be updated inplace.
tensor[indices] = updates
```
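For intuition, here is an illustrative NumPy analogue of that assignment (the variable names are ours, not part of the API); note that NumPy mutates the array in place, while `tf.tensor_scatter_nd_update` returns a new tensor:

```python
import numpy as np

# Index vectors of depth equal to the array rank, each addressing one scalar.
array = np.zeros(8, dtype=np.int32)
indices = np.array([[1], [3], [4], [7]])
updates = [9, 10, 11, 12]
array[tuple(indices.T)] = updates   # in-place assignment
print(array)   # [ 0  9  0 10 11  0  0 12]
```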
If an out of bound index is found on CPU, an error is returned.

On GPU, the semantics are different:

- If an out of bound index is found, the index is ignored.
- The order in which updates are applied is nondeterministic, so the output will be nondeterministic if `indices` contains duplicates (see the sketch below).
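To see the duplicate-index hazard concretely, here is a minimal sketch (the values are illustrative): both updates target position 1, and the op does not specify which one wins.

```python
import tensorflow as tf

tensor = tf.zeros([4], dtype=tf.int32)
indices = [[1], [1]]   # duplicate index: both updates target position 1
updates = [10, 20]
# The application order is unspecified, so position 1 may end up as
# either 10 or 20; do not rely on the printed value.
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
```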
This operation is very similar to `tf.scatter_nd`, except that the updates are scattered onto an existing tensor (as opposed to a zero-tensor). If the memory for the existing tensor cannot be re-used, a copy is made and updated.
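A minimal side-by-side sketch of the two ops (the values are illustrative):

```python
import tensorflow as tf

indices = [[1], [3]]
updates = [9, 10]

# tf.scatter_nd scatters into an implicit zero tensor of the given shape...
print(tf.scatter_nd(indices, updates, shape=[5]))
# tf.Tensor([ 0  9  0 10  0], shape=(5,), dtype=int32)

# ...while tf.tensor_scatter_nd_update scatters into an existing tensor.
print(tf.tensor_scatter_nd_update(tf.ones([5], dtype=tf.int32), indices, updates))
# tf.Tensor([ 1  9  1 10  1], shape=(5,), dtype=int32)
```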
In general:

- `indices` is an integer tensor - the indices to update in `tensor`.
- `indices` has at least two axes; the last axis is the depth of the index vectors.
- For each index vector in `indices` there is a corresponding entry in `updates`.
- If the length of the index vectors matches the rank of `tensor`, then the index vectors each point to scalars in `tensor` and each update is a scalar.
- If the length of the index vectors is less than the rank of `tensor`, then the index vectors each point to slices of `tensor` and the shape of the updates must match that slice.

Overall this leads to the following shape constraints:
```python
assert tf.rank(indices) >= 2
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
assert index_depth <= tf.rank(tensor)
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == batch_shape + inner_shape
```
Typical usage is often much simpler than this general form, and it can be better understood starting with simple examples:
The simplest usage inserts scalar elements into a tensor by index. In this case, the `index_depth` must equal the rank of the input `tensor`, and each column of `indices` is an index into an axis of the input `tensor`.
In this simplest case the shape constraints are:
```python
num_updates, index_depth = indices.shape.as_list()
assert updates.shape == [num_updates]
assert index_depth == tf.rank(tensor)
```
For example, to insert 4 scattered elements into a rank-1 tensor with 8 elements, the scatter operation looks like this:
```python
tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1
updates = [9, 10, 11, 12]            # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)
```
The length (first axis) of `updates` must equal the length of the `indices`: `num_updates`. This is the number of updates being inserted. Each scalar update is inserted into `tensor` at the indexed location.
For a higher rank input `tensor`, scalar updates can be inserted by using an `index_depth` that matches `tf.rank(tensor)`:
```python
tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
updates = [5, 10]                    # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor(
[[ 1  5]
 [ 1  1]
 [10  1]], shape=(3, 2), dtype=int32)
```
When the input `tensor` has more than one axis, scatter can be used to update entire slices.

In this case it's helpful to think of the input `tensor` as being a two-level array-of-arrays. The shape of this two-level array is split into the `outer_shape` and the `inner_shape`.
`indices` indexes into the outer level of the input tensor (`outer_shape`), and replaces the sub-array at that location with the corresponding item from the `updates` list. The shape of each update is `inner_shape`.
When updating a list of slices, the shape constraints are:
```python
num_updates, index_depth = indices.shape.as_list()
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == [num_updates] + inner_shape
```
For example, to update rows of a `(6, 3)` `tensor`:
```python
tensor = tf.zeros([6, 3], dtype=tf.int32)
```
Use an index depth of one.
```python
indices = tf.constant([[2], [4]])    # num_updates == 2, index_depth == 1
num_updates, index_depth = indices.shape.as_list()
```
The `outer_shape` is `(6,)`, the `inner_shape` is `(3,)`:
```python
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
```
2 rows are being indexed, so 2 `updates` must be supplied. Each update must be shaped to match the `inner_shape`.
```python
# num_updates == 2, inner_shape == 3
updates = tf.constant([[1, 2, 3], [4, 5, 6]])
```
Altogether this gives:
```python
tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
array([[0, 0, 0],
       [0, 0, 0],
       [1, 2, 3],
       [0, 0, 0],
       [4, 5, 6],
       [0, 0, 0]], dtype=int32)
```
A tensor representing a batch of uniformly sized video clips naturally has 5 axes: `[batch_size, time, width, height, channels]`.
```python
batch_size, time, width, height, channels = 13, 11, 7, 5, 3
video_batch = tf.zeros([batch_size, time, width, height, channels])
```
To replace a selection of video clips:

- Use an `index_depth` of 1 (indexing the `outer_shape`: `[batch_size]`).
- Provide updates each with a shape matching the `inner_shape`: `[time, width, height, channels]`.

To replace the first two clips with ones:
```python
indices = [[0], [1]]
new_clips = tf.ones([2, time, width, height, channels])
tf.tensor_scatter_nd_update(video_batch, indices, new_clips)
```
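As a quick sanity check (not part of the original example; it reuses `video_batch`, `indices`, and `new_clips` from above), the first two clips are now all ones and the rest remain zeros:

```python
result = tf.tensor_scatter_nd_update(video_batch, indices, new_clips)
print(tf.reduce_all(result[:2] == 1.0).numpy())   # True
print(tf.reduce_all(result[2:] == 0.0).numpy())   # True
```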
To replace a selection of frames in the videos:

- `indices` must have an `index_depth` of 2 for the `outer_shape`: `[batch_size, time]`.
- `updates` must be shaped like a list of images. Each update must have a shape matching the `inner_shape`: `[width, height, channels]`.

To replace the first frame of the first three video clips:
```python
indices = [[0, 0], [1, 0], [2, 0]]    # num_updates == 3, index_depth == 2
new_images = tf.ones([
    # num_updates == 3, inner_shape == (width, height, channels)
    3, width, height, channels])
tf.tensor_scatter_nd_update(video_batch, indices, new_images)
```
In simple cases it's convenient to think of `indices` and `updates` as lists, but this is not a strict requirement. Instead of a flat `num_updates`, the `indices` and `updates` can be folded into a `batch_shape`. This `batch_shape` is all axes of the `indices`, except for the innermost `index_depth` axis.
```python
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
```
Note: The one exception is that the `batch_shape` cannot be `[]`. You can't update a single index by passing indices with shape `[index_depth]`.
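A minimal sketch of the workaround (the values are illustrative): give `indices` a leading `num_updates` axis of size 1, so its shape is `[1, index_depth]` rather than `[index_depth]`.

```python
import tensorflow as tf

tensor = tf.zeros([3])
# Rejected: indices of shape [index_depth] has an empty batch_shape.
# tf.tensor_scatter_nd_update(tensor, [0], [5.0])   # raises an error
# Accepted: indices of shape [1, index_depth] describes one update.
print(tf.tensor_scatter_nd_update(tensor, [[0]], [5.0]))
# tf.Tensor([5. 0. 0.], shape=(3,), dtype=float32)
```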
`updates` must have a matching `batch_shape` (the axes before `inner_shape`).
```python
assert updates.shape == batch_shape + inner_shape
```
Note: The result is equivalent to flattening the `batch_shape` axes of `indices` and `updates`. This generalization just avoids the need for reshapes when it is more natural to construct "folded" indices and updates.
With this generalization the full shape constraints are:
```python
assert tf.rank(indices) >= 2
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
assert index_depth <= tf.rank(tensor)
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == batch_shape + inner_shape
```
For example, to draw an `X` on a `(5,5)` matrix, start with these indices:
```python
tensor = tf.zeros([5,5])
indices = tf.constant([
    [[0,0], [1,1], [2,2], [3,3], [4,4]],
    [[0,4], [1,3], [2,2], [3,1], [4,0]],
])
indices.shape.as_list()    # batch_shape == [2, 5], index_depth == 2
[2, 5, 2]
```
Here the `indices` do not have a shape of `[num_updates, index_depth]`, but a shape of `batch_shape + [index_depth]`.
Since the `index_depth` is equal to the rank of `tensor`:

- `outer_shape` is `(5,5)`
- `inner_shape` is `()` - each update is scalar
- `updates.shape` is `batch_shape + inner_shape == (2,5) + ()`
```python
updates = [
    [1,1,1,1,1],
    [1,1,1,1,1],
]
```
Putting this together gives:
```python
tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
array([[1., 0., 0., 0., 1.],
       [0., 1., 0., 1., 0.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 1., 0.],
       [1., 0., 0., 0., 1.]], dtype=float32)
```
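As a check of the flattening equivalence noted above (reusing `tensor`, `indices`, and `updates` from this example), collapsing the `batch_shape` axes produces the same result:

```python
flat_indices = tf.reshape(indices, [-1, 2])                     # shape [10, 2]
flat_updates = tf.reshape(tf.cast(updates, tf.float32), [-1])   # shape [10]
# Prints the same X pattern as above.
print(tf.tensor_scatter_nd_update(tensor, flat_indices, flat_updates))
```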
Args | |
---|---|
`tensor` | Tensor to copy/update. |
`indices` | Indices to update. |
`updates` | Updates to apply at the indices. |
`name` | Optional name for the operation. |
Returns | |
---|---|
A new tensor with the given shape and updates applied according to the indices. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/tensor_scatter_nd_update