An Op to exchange data across TPU replicas.
tf.raw_ops.AllToAll(input, group_assignment, concat_dimension, split_dimension, split_count, name=None)
On each replica, the input is split into split_count blocks along split_dimension, and the blocks are sent to the other replicas given by group_assignment. After receiving split_count - 1 blocks from the other replicas, each replica concatenates them with the block it kept along concat_dimension to form the output.
For example, suppose there are 2 TPU replicas:
replica 0 receives input: [[A, B]]
replica 1 receives input: [[C, D]]
group_assignment=[[0, 1]]
concat_dimension=0
split_dimension=1
split_count=2
replica 0's output: [[A], [C]]
replica 1's output: [[B], [D]]
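
The exchange described above can be sketched on a single host with plain nested Python lists, which may help when reasoning about the semantics without a TPU. This is only an illustrative simulation, not the op's implementation: the helper names (split_axis, concat_axis, simulate_all_to_all) are hypothetical, and it assumes that the replica at position k within a group receives block k from every group member and that blocks are concatenated in group order, which is consistent with the two-replica example.

```python
def split_axis(x, axis, count):
    """Split nested list `x` into `count` equal blocks along `axis`."""
    if axis == 0:
        if len(x) % count != 0:
            raise ValueError("dimension size must be divisible by count")
        size = len(x) // count
        return [x[i * size:(i + 1) * size] for i in range(count)]
    # Split every sub-list one level down, then regroup so that
    # block k collects the k-th piece of each sub-list.
    pieces = [split_axis(sub, axis - 1, count) for sub in x]
    return [[p[k] for p in pieces] for k in range(count)]


def concat_axis(blocks, axis):
    """Concatenate a list of nested-list blocks along `axis`."""
    if axis == 0:
        return [item for block in blocks for item in block]
    return [
        concat_axis([block[i] for block in blocks], axis - 1)
        for i in range(len(blocks[0]))
    ]


def simulate_all_to_all(inputs, group_assignment, concat_dimension,
                        split_dimension, split_count):
    """Simulate the exchange; `inputs[r]` is the input of replica r.

    Assumption (not confirmed by the documentation): the k-th member of a
    group receives block k from each member, concatenated in group order.
    """
    outputs = [None] * len(inputs)
    for group in group_assignment:
        if len(group) != split_count:
            raise ValueError("split_count must equal the sub-group size")
        # blocks[j][k] is block k of the j-th group member's input.
        blocks = [split_axis(inputs[r], split_dimension, split_count)
                  for r in group]
        for k, replica in enumerate(group):
            received = [blocks[j][k] for j in range(split_count)]
            outputs[replica] = concat_axis(received, concat_dimension)
    return outputs


# The two-replica example from the text.
example_outputs = simulate_all_to_all(
    inputs=[[["A", "B"]], [["C", "D"]]],
    group_assignment=[[0, 1]],
    concat_dimension=0,
    split_dimension=1,
    split_count=2,
)
print(example_outputs[0])  # [['A'], ['C']]
print(example_outputs[1])  # [['B'], ['D']]
```

Replicas that appear in no group are left as None in this sketch; the documentation does not say how the real op treats them.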
Args | |
---|---|
input | A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, half, uint32, uint64, bool. The local input to the all-to-all exchange. |
group_assignment | A Tensor of type int32 with shape [num_groups, num_replicas_per_group]. group_assignment[i] lists the replica IDs in the i-th subgroup. |
concat_dimension | An int. The dimension along which to concatenate. |
split_dimension | An int. The dimension along which to split. |
split_count | An int. The number of splits. This number must equal the sub-group size (group_assignment.get_shape()[1]). |
name | A name for the operation (optional). |
Returns | |
---|---|
A Tensor. Has the same type as input. |
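
The shape of the returned tensor is not stated here, but it can be inferred from the description: each replica keeps blocks that are 1/split_count of the input along split_dimension and concatenates split_count of them along concat_dimension. A minimal sketch of that arithmetic follows; the function name all_to_all_output_shape is hypothetical, and the requirement that the split dimension be evenly divisible by split_count is an assumption rather than something this page documents.

```python
def all_to_all_output_shape(input_shape, concat_dimension,
                            split_dimension, split_count):
    """Infer the per-replica output shape from the op description.

    Assumption: the size of `split_dimension` must be divisible by
    `split_count`; this page does not state that requirement.
    """
    shape = list(input_shape)
    if shape[split_dimension] % split_count != 0:
        raise ValueError("split dimension is not divisible by split_count")
    # Each block is 1/split_count of the input along the split dimension...
    shape[split_dimension] //= split_count
    # ...and split_count blocks are joined along the concat dimension.
    shape[concat_dimension] *= split_count
    return shape


# Input [[A, B]] has shape [1, 2]; the documented output [[A], [C]]
# has shape [2, 1].
print(all_to_all_output_shape([1, 2], concat_dimension=0,
                              split_dimension=1, split_count=2))  # [2, 1]
```

When concat_dimension and split_dimension are the same, the two adjustments cancel and the output shape equals the input shape.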
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/raw_ops/AllToAll