Represents a Mesh configuration over a certain list of Mesh Dimensions.
tf.experimental.dtensor.Mesh(
dim_names: List[str],
global_device_ids: np.ndarray,
local_device_ids: List[int],
local_devices: List[Union[tf_device.DeviceSpec, str]],
mesh_name: str = '',
global_devices: Optional[List[Union[tf_device.DeviceSpec, str]]] = None,
use_xla_spmd: bool = USE_XLA_SPMD
)
A mesh consists of named dimensions with sizes, which describe how a set of devices is arranged. Defining tensor layouts in terms of mesh dimensions allows us to efficiently determine the communication required when computing an operation with tensors of different layouts.
A mesh provides information not only about the placement of the tensors but also the topology of the underlying devices. For example, we can group 8 TPUs as a 1-D array for data parallelism or a 2x4 grid for (2-way) data parallelism and (4-way) model parallelism.
Refer to DTensor Concepts for an in-depth discussion and examples.
Note: the utilities dtensor.create_mesh and dtensor.create_distributed_mesh provide a simpler API to create meshes for single- or multi-client use cases.
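As a rough illustration of the two arrangements mentioned above, the sketch below uses only NumPy to lay out eight device IDs as a 1-D mesh and as a 2x4 mesh. The dimension names "batch" and "model" are hypothetical labels chosen for this example, and no DTensor API is called.

```python
import numpy as np

# Eight global device IDs, numbered sequentially from 0.
device_ids = np.arange(8)

# 1-D arrangement: a single mesh dimension of size 8 (data parallelism only).
mesh_1d_names = ["batch"]  # hypothetical dimension name
mesh_1d_ids = device_ids.reshape(8)

# 2x4 arrangement: 2-way data parallelism by 4-way model parallelism.
mesh_2d_names = ["batch", "model"]  # hypothetical dimension names
mesh_2d_ids = device_ids.reshape(2, 4)

# The shape of the ID array gives the size of each named dimension.
mesh_1d_sizes = dict(zip(mesh_1d_names, mesh_1d_ids.shape))
mesh_2d_sizes = dict(zip(mesh_2d_names, mesh_2d_ids.shape))
print(mesh_1d_sizes)
print(mesh_2d_sizes)
```

The same eight devices back both layouts; only the shape of the ID array, and therefore the set of named dimensions, changes.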
| Args | |
|---|---|
| dim_names | A list of strings indicating dimension names. |
| global_device_ids | An ndarray of global device IDs, used to compose DeviceSpecs describing the mesh. The shape of this array determines the size of each mesh dimension. Values in this array should increment sequentially from 0. This argument is the same for every DTensor client. |
| local_device_ids | A list of local device IDs equal to a subset of values in global_device_ids. They indicate the position of local devices in the global mesh. Different DTensor clients must contain distinct local_device_ids contents. All local_device_ids from all DTensor clients must cover every element in global_device_ids. |
| local_devices | The list of devices hosted locally. The elements correspond 1:1 to those of local_device_ids. |
| mesh_name | The name of the mesh. Currently rarely used; it mostly indicates whether the mesh is CPU-, GPU-, or TPU-based. |
| global_devices | Optional. The list of global devices. Set when multiple device meshes are in use. |
| use_xla_spmd | Optional boolean. When True, XLA SPMD is used instead of DTensor SPMD. |
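The relationship between global_device_ids and the per-client local_device_ids can be checked with a small sketch. It assumes a hypothetical two-client setup with four devices per client, and it reads "distinct" as "non-overlapping". The helper check_local_ids is invented for this example and is not part of DTensor.

```python
import numpy as np

def check_local_ids(global_device_ids, local_ids_per_client):
    """Return True if the per-client ID lists are disjoint and cover the mesh."""
    all_global = sorted(int(i) for i in global_device_ids.flatten())
    all_local = sorted(i for ids in local_ids_per_client for i in ids)
    # Sorted equality holds only when no ID is repeated and none is missing.
    return all_local == all_global

# A 2x4 mesh: the array shape fixes the size of each mesh dimension.
global_device_ids = np.arange(8).reshape(2, 4)

# Hypothetical clients: each one hosts four of the eight devices.
clients = [[0, 1, 2, 3], [4, 5, 6, 7]]
overlapping = [[0, 1, 2, 3], [3, 4, 5, 6]]

print(check_local_ids(global_device_ids, clients))      # True
print(check_local_ids(global_device_ids, overlapping))  # False
```

In the second call, ID 3 appears on both clients and ID 7 is missing, so the check fails.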
| Attributes | |
|---|---|
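| dim_names | |
| name | |
| single_device | |
| size | |
| strides | Returns the strides tensor array for this mesh. If the mesh shape is [a, b, c, d], the strides array is [b*c*d, c*d, d, 1], and the coordinates of a device in the mesh are [(device_id / (b*c*d)) % a, (device_id / (c*d)) % b, (device_id / (d)) % c, (device_id) % d]. This is the same as (device_id // mesh.strides) % mesh.shape. |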
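The stride arithmetic in the strides entry can be sketched in plain Python. This assumes device IDs are assigned in row-major order; strides_for and coords_for are hypothetical helpers written for this example, not DTensor functions.

```python
def strides_for(shape):
    """Row-major strides: for shape [a, b, c, d] this is [b*c*d, c*d, d, 1]."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension back to the first.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def coords_for(device_id, shape):
    """Mesh coordinates of a device: (device_id // stride) % size per dimension."""
    return [(device_id // s) % n for s, n in zip(strides_for(shape), shape)]

shape = [2, 3, 4]
print(strides_for(shape))     # [12, 4, 1]
print(coords_for(17, shape))  # [1, 1, 1]
```

For device 17 in a 2x3x4 mesh, 17 = 1*12 + 1*4 + 1*1, which gives the coordinate [1, 1, 1].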
as_proto
as_proto()
as_proto(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> tensorflow::dtensor::MeshProto
Returns the MeshProto protobuf message.
contains_dim
contains_dim()
contains_dim(self: tensorflow.python._pywrap_dtensor_device.Mesh, dim_name: str) -> bool
Returns True if a Mesh contains the given dimension name.
coords
coords(
device_idx: int
) -> tf.Tensor
Converts the device index into a tensor of mesh coordinates.
device_location
device_location()
device_location(self: tensorflow.python._pywrap_dtensor_device.Mesh, arg0: int) -> List[int]
device_type
device_type()
device_type(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> str
Returns the device_type of a Mesh.
dim_size
dim_size()
dim_size(self: tensorflow.python._pywrap_dtensor_device.Mesh, dim_name: str) -> int
Returns the size of the given mesh dimension.
from_device
@classmethod
from_device(
device: str
) -> 'Mesh'
Constructs a single device mesh from a device string.
from_proto
@classmethod
from_proto(
proto: layout_pb2.MeshProto
) -> 'Mesh'
Constructs a mesh instance from the input proto.
from_string
@classmethod
from_string(
mesh_str: str
) -> 'Mesh'
global_device_ids
global_device_ids() -> np.ndarray
Returns a global device list as an array.
global_devices
global_devices()
global_devices(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> List[str]
Returns a list of global device specs represented as strings.
host_mesh
host_mesh() -> 'Mesh'
Returns a host mesh.
is_remote
is_remote()
is_remote(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool
Returns True if a Mesh contains only remote devices.
is_single_device
is_single_device()
is_single_device(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool
Returns True if the mesh represents a non-distributed device.
local_device_ids
local_device_ids()
local_device_ids(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> Span[int]
Returns a list of local device IDs.
local_device_locations
local_device_locations() -> List[Dict[str, int]]
Returns a list of local device locations.
A device location is a dictionary from dimension names to indices on those dimensions.
local_devices
local_devices()
local_devices(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> Span[str]
Returns a list of local device specs represented as strings.
min_global_device_id
min_global_device_id()
min_global_device_id(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> int
Returns the minimum global device ID.
num_local_devices
num_local_devices()
num_local_devices(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> int
Returns the number of local devices.
shape
shape()
shape(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> List[int]
Returns the shape of the mesh.
to_string
to_string()
to_string(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> str
Returns the string representation of the Mesh.
unravel_index
unravel_index()
Returns a dictionary from device ID to {dim_name: dim_index}.
For example, for a 3x2 mesh with dimensions x and y, this returns:
{ 0: {'x': 0, 'y': 0},
  1: {'x': 0, 'y': 1},
  2: {'x': 1, 'y': 0},
  3: {'x': 1, 'y': 1},
  4: {'x': 2, 'y': 0},
  5: {'x': 2, 'y': 1} }
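A mapping of this form can be reproduced in plain Python, which may help when checking expectations about device placement. The function unravel_all is a hypothetical stand-in that assumes device IDs are assigned in row-major order; it is not the DTensor implementation.

```python
import itertools

def unravel_all(dim_names, dim_sizes):
    """Map each device ID to a {dim_name: dim_index} dictionary."""
    index_ranges = [range(n) for n in dim_sizes]
    # itertools.product iterates with the last dimension varying fastest,
    # which matches row-major numbering of device IDs.
    return {
        device_id: dict(zip(dim_names, indices))
        for device_id, indices in enumerate(itertools.product(*index_ranges))
    }

mapping = unravel_all(["x", "y"], [3, 2])
print(mapping[3])  # {'x': 1, 'y': 1}
```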
use_xla_spmd
use_xla_spmd()
use_xla_spmd(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool
Returns True if Mesh will use XLA for SPMD instead of DTensor SPMD.
__contains__
__contains__()
__contains__(self: tensorflow.python._pywrap_dtensor_device.Mesh, dim_name: str) -> bool
__eq__
__eq__()
__eq__(self: tensorflow.python._pywrap_dtensor_device.Mesh, arg0: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool
__getitem__
__getitem__(
dim_name: str
) -> MeshDimension
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh