Represents the layout information of a DTensor.
```python
tf.experimental.dtensor.Layout(
    sharding_specs: List[str],
    mesh: tf.experimental.dtensor.Mesh
)
```
A layout describes how a distributed tensor is partitioned across a mesh (and thus across devices). For each axis of the tensor, the corresponding sharding spec indicates which dimension of the mesh it is sharded over. The special sharding spec UNSHARDED indicates that the axis is replicated on all devices of the mesh.
Refer to the DTensor Concepts guide for an in-depth discussion and examples.
For example, consider a 1-D mesh:
```python
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
```
This mesh arranges 6 TPU devices into a 1-D array. Layout([UNSHARDED], mesh) is a layout for a rank-1 tensor that is replicated across all 6 devices.
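A minimal runnable sketch of this example, assuming a single host where six logical CPU devices stand in for the TPUs (the helper tf.experimental.dtensor.create_mesh builds the device list):
```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: split one physical CPU into 6 logical devices so the
# example runs on a single host without TPUs.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 6)

# The 1-D mesh from the example, and a fully replicated rank-1 layout.
mesh = dtensor.create_mesh([("x", 6)], device_type="CPU")
layout = dtensor.Layout([dtensor.UNSHARDED], mesh)
print(layout.is_fully_replicated())  # True
```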
For another example, consider a 2-D mesh:
```python
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
     [("x", 3), ("y", 2)])
```
This mesh arranges 6 TPU devices into a 3x2 2-D array. Layout(["x", UNSHARDED], mesh) is a layout for a rank-2 tensor whose first axis is sharded on mesh dimension "x" and whose second axis is replicated. If we place np.arange(6).reshape((3, 2)) using this layout, the individual component tensors look like:
| Device | Component |
|---|---|
| TPU:0 | [[0, 1]] |
| TPU:1 | [[0, 1]] |
| TPU:2 | [[2, 3]] |
| TPU:3 | [[2, 3]] |
| TPU:4 | [[4, 5]] |
| TPU:5 | [[4, 5]] |
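A sketch of this 2-D case, reusing the logical CPU setup from the snippet above in place of the TPUs:
```python
# The 3x2 mesh from the example (CPU stand-ins for the TPUs).
mesh2d = dtensor.create_mesh([("x", 3), ("y", 2)], device_type="CPU")

# First tensor axis sharded over mesh dimension "x"; second axis replicated.
layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh2d)
print(layout.num_shards(0))  # 3 -- the first axis is split into 3 shards
print(layout.num_shards(1))  # 1 -- the second axis is replicated
```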
| Args | |
|---|---|
| sharding_specs | List of sharding specifications, each corresponding to a tensor axis. Each specification (dim_sharding) can either be a mesh dimension or the special value UNSHARDED. |
| mesh | A mesh configuration for the tensor. |
| Attributes | |
|---|---|
| mesh | The Mesh this layout is defined on. |
| rank | The number of tensor axes (sharding specs) in this layout. |
| shape | |
| sharding_specs | The list of sharding specs, one per tensor axis. |
| type | |
as_proto
```python
as_proto(
    self: tensorflow.python._pywrap_dtensor_device.Layout
) -> tensorflow::dtensor::LayoutProto
```
Returns the LayoutProto protobuf message.
batch_sharded
```python
@classmethod
batch_sharded(
    mesh: tf.experimental.dtensor.Mesh,
    batch_dim: str,
    rank: int,
    axis: int = 0
) -> 'Layout'
```
Returns a layout sharded on the batch dimension.
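For illustration, a sketch reusing mesh2d from the example above; batch_sharded is shorthand for sharding one axis (the batch axis) and replicating the rest:
```python
# Rank-3 layout with the batch axis (axis 0) sharded over "x".
layout = dtensor.Layout.batch_sharded(mesh2d, batch_dim="x", rank=3)
# Equivalent to:
# dtensor.Layout(["x", dtensor.UNSHARDED, dtensor.UNSHARDED], mesh2d)
```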
delete
```python
delete(
    dims: List[int]
) -> 'Layout'
```
Returns the layout with the given dimensions deleted.
from_device
```python
@classmethod
from_device(
    device: str
) -> 'Layout'
```
Constructs a single device layout from a single device mesh.
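For example (a sketch; the device string names a single local device):
```python
layout = dtensor.Layout.from_device("CPU:0")
print(layout.is_single_device())  # True
```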
from_proto
```python
@classmethod
from_proto(
    layout_proto: layout_pb2.LayoutProto
) -> 'Layout'
```
Creates an instance from a LayoutProto.
from_single_device_mesh
```python
@classmethod
from_single_device_mesh(
    mesh: tf.experimental.dtensor.Mesh
) -> 'Layout'
```
Constructs a single device layout from a single device mesh.
from_string
```python
@classmethod
from_string(
    layout_str: str
) -> 'Layout'
```
Creates an instance from a human-readable string.
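A sketch of a round trip through the string form, for a layout built as in the earlier snippets; the exact string syntax is an implementation detail, so it is safest to obtain it from to_string:
```python
s = layout.to_string()
restored = dtensor.Layout.from_string(s)
print(restored == layout)  # True
```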
global_shape_from_local_shape
```python
global_shape_from_local_shape(
    self: tensorflow.python._pywrap_dtensor_device.Layout,
    local_shape: List[int]
) -> tuple
```
Returns the global shape computed from this local shape.
inner_sharded
```python
@classmethod
inner_sharded(
    mesh: tf.experimental.dtensor.Mesh,
    inner_dim: str,
    rank: int
) -> 'Layout'
```
Returns a layout sharded on the inner (last) dimension.
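A sketch mirroring batch_sharded, but for the last axis (reusing mesh2d from above):
```python
# Rank-2 layout with the inner (last) axis sharded over "y".
layout = dtensor.Layout.inner_sharded(mesh2d, inner_dim="y", rank=2)
# Equivalent to dtensor.Layout([dtensor.UNSHARDED, "y"], mesh2d)
```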
is_batch_parallel
```python
is_batch_parallel(
    self: tensorflow.python._pywrap_dtensor_device.Layout
) -> bool
```
is_fully_replicated
```python
is_fully_replicated(
    self: tensorflow.python._pywrap_dtensor_device.Layout
) -> bool
```
Returns True if all tensor axes are replicated.
is_single_device
```python
is_single_device(
    self: tensorflow.python._pywrap_dtensor_device.Layout
) -> bool
```
Returns True if the Layout represents a non-distributed device.
local_shape_from_global_shape
```python
local_shape_from_global_shape(
    self: tensorflow.python._pywrap_dtensor_device.Layout,
    global_shape: List[int]
) -> tuple
```
Returns the local shape computed from this global shape.
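A sketch using the ["x", UNSHARDED] layout on the 3x2 mesh from above: the first axis is split across the 3 shards of "x", so a global shape of [6, 4] corresponds to a per-device local shape of (2, 4):
```python
layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh2d)
print(layout.local_shape_from_global_shape([6, 4]))  # (2, 4)
print(layout.global_shape_from_local_shape([2, 4]))  # (6, 4)
```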
num_shards
```python
num_shards(
    self: tensorflow.python._pywrap_dtensor_device.Layout,
    idx: int
) -> int
```
Returns the number of shards for tensor dimension idx.
offset_to_shard
```python
offset_to_shard()
```
Mapping from offset in a flattened list to shard index.
offset_tuple_to_global_index
```python
offset_tuple_to_global_index(
    offset_tuple
)
```
Mapping from an offset tuple to an index in the global tensor.
replicated
```python
@classmethod
replicated(
    mesh: tf.experimental.dtensor.Mesh,
    rank: int
) -> 'Layout'
```
Returns a replicated layout of rank `rank`.
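For example (a sketch, reusing mesh2d from above):
```python
layout = dtensor.Layout.replicated(mesh2d, rank=2)
# Equivalent to dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh2d)
print(layout.is_fully_replicated())  # True
```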
to_parted
```python
to_parted() -> 'Layout'
```
Returns a "parted" layout from a static layout.
A parted layout contains axes that are treated as independent by most SPMD expanders.
FIXME(b/285905569): The exact semantics are still being investigated.
to_string
```python
to_string(
    self: tensorflow.python._pywrap_dtensor_device.Layout
) -> str
```
`__eq__`
```python
__eq__(
    self: tensorflow.python._pywrap_dtensor_device.Layout,
    arg0: tensorflow.python._pywrap_dtensor_device.Layout
) -> bool
```
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Layout