Additional APIs for algorithms that need to be distribution-aware.
Inherits From: StrategyExtended
tf.compat.v1.distribute.StrategyExtended( container_strategy )
Note: For most usage of tf.distribute.Strategy, there should be no need to call these methods, since TensorFlow libraries (such as optimizers) already call these methods when needed on your behalf.
Some common use cases of functions on this page:
* Locality

tf.distribute.DistributedValues can have the same locality as a distributed variable, which leads to a mirrored value residing on the same devices as the variable (as opposed to the compute devices). Such values may be passed to a call to tf.distribute.StrategyExtended.update to update the value of a variable. You may use tf.distribute.StrategyExtended.colocate_vars_with to give a variable the same locality as another variable. You may convert a "PerReplica" value to a variable's locality by using tf.distribute.StrategyExtended.reduce_to or tf.distribute.StrategyExtended.batch_reduce_to.

* How to update a distributed variable
A distributed variable is a variable created on multiple devices. As discussed in the glossary, mirrored variables and SyncOnRead variables are two examples. The standard pattern for updating distributed variables is to:

1. In your function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be a gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call() with the updates and variables as arguments.
3. Call tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v) (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates.
4. Call tf.distribute.StrategyExtended.update(v) for each variable to update its value.

Steps 2 through 4 are done automatically by class tf.keras.optimizers.Optimizer if you call its tf.keras.optimizers.Optimizer.apply_gradients method in a replica context; a minimal sketch of that optimizer path follows this list.
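The following is not part of the original description; it is a minimal sketch, assuming two GPUs are visible and using an illustrative variable `w`, loss, and SGD optimizer, of calling apply_gradients in a replica context so that steps 2 through 4 happen behind the scenes:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  w = tf.Variable(2.0)                      # illustrative distributed variable
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(x):
  with tf.GradientTape() as tape:
    loss = (w * x - 1.0) ** 2
  grads = tape.gradient(loss, [w])          # step 1: (update, variable) pairs
  # apply_gradients in replica context performs the merge_call, the
  # cross-replica reduction, and the per-component update (steps 2 through 4).
  opt.apply_gradients(zip(grads, [w]))

strategy.run(train_step, args=(tf.constant(1.0),))
```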
In fact, a higher-level solution to update a distributed variable is to call assign on the variable as you would with a regular tf.Variable. You can call the method in both replica context and cross-replica context. For a mirrored variable, calling assign in replica context requires you to specify the aggregation type in the variable constructor. In that case, the context switching and sync described in steps 2 through 4 are handled for you. If you call assign on a mirrored variable in cross-replica context, you can only assign a single value or assign values from another mirrored variable or a mirrored tf.distribute.DistributedValues. For a SyncOnRead variable, in replica context, you can simply call assign on it and no aggregation happens under the hood. In cross-replica context, you can only assign a single value to a SyncOnRead variable. One example case is restoring from a checkpoint: if the aggregation type of the variable is tf.VariableAggregation.SUM, it is assumed that replica values were added before checkpointing, so at the time of restoring, the value is divided by the number of replicas and then assigned to each replica; if the aggregation type is tf.VariableAggregation.MEAN, the value is assigned to each replica directly.
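As a rough illustration of the replica-context assign path described above (a minimal sketch, not from the original page; it assumes two GPUs are visible and uses a SUM aggregation chosen purely for the example):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  # The aggregation type must be specified so that replica-context assigns
  # can be synced across replicas on your behalf.
  v = tf.Variable(0., aggregation=tf.VariableAggregation.SUM)

@tf.function
def step_fn():
  v.assign_add(1.0)   # each replica contributes 1.0; SUM adds them together

strategy.run(step_fn)
# With two replicas, each component of `v` should now hold 2.0.
```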
Attributes | |
---|---|
experimental_between_graph | Whether the strategy uses between-graph replication or not. This is expected to return a constant value that will not be changed throughout its life cycle. |
experimental_require_static_shapes | Returns True if static shape is required; False otherwise. |
experimental_should_init | Whether initialization is needed. |
parameter_devices | Returns the tuple of all devices used to place variables. |
should_checkpoint | Whether checkpointing is needed. |
should_save_summary | Whether saving summaries is needed. |
worker_devices | Returns the tuple of all devices used for compute replica execution. |
batch_reduce_to
batch_reduce_to( reduce_op, value_destination_pairs, options=None )
Combine multiple reduce_to calls into one for faster execution.

Similar to reduce_to, but accepts a list of (value, destinations) pairs. It's more efficient than reducing each value separately.
This API currently can only be called in cross-replica context. Other variants to reduce values across replicas are:

* tf.distribute.StrategyExtended.reduce_to: the non-batch version of this API.
* tf.distribute.ReplicaContext.all_reduce: the counterpart of this API in replica context. It supports both batched and non-batched all-reduce.
* tf.distribute.Strategy.reduce: a more convenient method to reduce to the host in cross-replica context.

See reduce_to for more information.
```python
@tf.function
def step_fn(var):

  def merge_fn(strategy, value, var):
    # All-reduce the value. Note that `value` here is a
    # `tf.distribute.DistributedValues`.
    reduced = strategy.extended.batch_reduce_to(
        tf.distribute.ReduceOp.SUM, [(value, var)])[0]
    strategy.extended.update(var, lambda var, value: var.assign(value),
                             args=(reduced,))

  value = tf.identity(1.)
  tf.distribute.get_replica_context().merge_call(merge_fn,
                                                 args=(value, var))

def run(strategy):
  with strategy.scope():
    v = tf.Variable(0.)
    strategy.run(step_fn, args=(v,))
    return v

run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
# MirroredVariable:{
#   0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
#   1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
# }
run(tf.distribute.experimental.CentralStorageStrategy(
    compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
# <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
run(tf.distribute.OneDeviceStrategy("GPU:0"))
# <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
```
Args | |
---|---|
reduce_op | a tf.distribute.ReduceOp value specifying how values should be combined. Allows using string representation of the enum such as "SUM", "MEAN". |
value_destination_pairs | a sequence of (value, destinations) pairs. See tf.distribute.StrategyExtended.reduce_to for descriptions. |
options | a tf.distribute.experimental.CommunicationOptions . Options to perform collective operations. This overrides the default options if the tf.distribute.Strategy takes one in the constructor. See tf.distribute.experimental.CommunicationOptions for details of the options. |
Returns | |
---|---|
A list of reduced values, one per pair in value_destination_pairs . |
broadcast_to
broadcast_to( tensor, destinations )
Mirror a tensor on one device to all worker devices.
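A minimal usage sketch (not from the original page; it assumes two GPUs are visible and uses a mirrored variable `v` only to name the destination devices; broadcast_to is called in cross-replica context):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(0.)             # its devices serve as the destinations

t = tf.constant(3.0)
# Called in cross-replica context; the result is `t` mirrored onto every
# device that `v` lives on.
mirrored_t = strategy.extended.broadcast_to(t, destinations=v)
```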
Args | |
---|---|
tensor | A Tensor value to broadcast. |
destinations | A mirrored variable or device string specifying the destination devices to copy tensor to. |
Returns | |
---|---|
A value mirrored to destinations devices. |
call_for_each_replica
call_for_each_replica( fn, args=(), kwargs=None )
Run fn once per replica.
fn may call tf.get_replica_context() to access methods such as replica_id_in_sync_group and merge_call().

merge_call() is used to communicate between the replicas and re-enter the cross-replica context. All replicas pause their execution when they encounter a merge_call() call. After that, the merge_fn function is executed. Its results are then unwrapped and given back to each replica call. After that, execution resumes until fn is complete or encounters another merge_call(). Example:
```python
# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.run(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a list:
  print(distribution.experimental_local_results(merged_results))
```
Args | |
---|---|
fn | function to run (will be run once per replica). |
args | Tuple or list with positional arguments for fn . |
kwargs | Dict with keyword arguments for fn . |
Returns | |
---|---|
Merged return value of fn across all replicas. |
colocate_vars_with
colocate_vars_with( colocate_with_variable )
Scope that controls which devices variables will be created on.
No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation, others work by using a tf.compat.v1.colocate_with() scope).

This may only be used inside self.scope().
```python
with strategy.scope():
  var1 = tf.Variable(...)
  with strategy.extended.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.Variable(...)
    var3 = tf.Variable(...)

  def fn(v1, v2, v3):
    # operates on v1 from var1, v2 from var2, and v3 from var3
    ...

  # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
  # too.
  strategy.extended.update(var1, fn, args=(var2, var3))
```
Args | |
---|---|
colocate_with_variable | A variable created in this strategy's scope() . Variables created while in the returned context manager will be on the same set of devices as colocate_with_variable . |
Returns | |
---|---|
A context manager. |
experimental_make_numpy_dataset
experimental_make_numpy_dataset( numpy_input, session=None )
Makes a dataset for input provided via a numpy array.
This avoids adding numpy_input as a large constant in the graph, and copies the data to the machine or machines that will be processing the input.
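A hedged usage sketch (not from the original page; written for TF 1.x graph execution since a session is involved, with an illustrative NumPy feature array):

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # this method targets v1-style graph execution

strategy = tf.distribute.MirroredStrategy()
features = np.random.rand(100, 3).astype(np.float32)

with tf.Session() as sess:
  # The array is fed in via the session instead of being embedded in the
  # graph as a large constant.
  dataset = strategy.extended.experimental_make_numpy_dataset(
      features, session=sess)
  dataset = dataset.batch(10)
```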
Args | |
---|---|
numpy_input | A nest of NumPy input arrays that will be distributed evenly across all replicas. Note that lists of Numpy arrays are stacked, as that is normal tf.data.Dataset behavior. |
session | (TensorFlow v1.x graph execution only) A session used for initialization. |
Returns | |
---|---|
A tf.data.Dataset representing numpy_input . |
experimental_run_steps_on_iterator
experimental_run_steps_on_iterator( fn, iterator, iterations=1, initial_loop_values=None )
DEPRECATED: please use run instead.

Run fn with input from iterator for iterations times.
This method can be used to run a step function for training a number of times using input from a dataset.
Args | |
---|---|
fn | function to run using this distribution strategy. The function must have the following signature: def fn(context, inputs) . context is an instance of MultiStepContext that will be passed when fn is run. context can be used to specify the outputs to be returned from fn by calling context.set_last_step_output . It can also be used to capture non tensor outputs by context.set_non_tensor_output . See MultiStepContext documentation for more information. inputs will have same type/structure as iterator.get_next() . Typically, fn will use call_for_each_replica method of the strategy to distribute the computation over multiple replicas. |
iterator | Iterator of a dataset that represents the input for fn . The caller is responsible for initializing the iterator as needed. |
iterations | (Optional) Number of iterations that fn should be run. Defaults to 1. |
initial_loop_values | (Optional) Initial values to be passed into the loop that runs fn . Defaults to None . |
Returns | |
---|---|
The MultiStepContext object, which has the following properties, among other things: the last-step outputs set by fn via context.set_last_step_output and the non-tensor outputs set via context.set_non_tensor_output . |
non_slot_devices
non_slot_devices( var_list )
Device(s) for non-slot variables.
DEPRECATED: TF 1.x ONLY.
This method returns the non-slot devices where non-slot variables are placed. Users can create non-slot variables on these devices by using a colocate_vars_with block:
```python
with tf.distribute.StrategyExtended.colocate_vars_with(
    tf.distribute.StrategyExtended.non_slot_devices(...)):
  ...
```
Args | |
---|---|
var_list | The list of variables being optimized, needed with the default tf.distribute.Strategy . |
Returns | |
---|---|
A sequence of devices for non-slot variables. |
read_var
read_var( v )
Reads the value of a variable.
Returns the aggregate value of a replica-local variable, or the (read-only) value of any other variable.
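A rough sketch (not from the original page; it assumes two GPUs are visible and uses a sync-on-read variable with SUM aggregation chosen for the example):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(
      0.,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

def step_fn():
  v.assign_add(1.0)   # per-replica update; no sync happens under the hood

strategy.run(step_fn)
# read_var aggregates the replica-local values, so this should print 2.0.
print(strategy.extended.read_var(v))
```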
Args | |
---|---|
v | A variable allocated within the scope of this tf.distribute.Strategy . |
Returns | |
---|---|
A tensor representing the value of v , aggregated across replicas if necessary. |
reduce_to
reduce_to( reduce_op, value, destinations, options=None )
Combine (via e.g. sum or mean) values across replicas.
reduce_to aggregates tf.distribute.DistributedValues and distributed variables. It supports both dense values and tf.IndexedSlices.
This API currently can only be called in cross-replica context. Other variants to reduce values across replicas are:

* tf.distribute.StrategyExtended.batch_reduce_to: the batch version of this API.
* tf.distribute.ReplicaContext.all_reduce: the counterpart of this API in replica context. It supports both batched and non-batched all-reduce.
* tf.distribute.Strategy.reduce: a more convenient method to reduce to the host in cross-replica context.

destinations specifies where to reduce the value to, e.g. "GPU:0". You can also pass in a Tensor, and the destinations will be the device of that tensor. For all-reduce, pass the same to value and destinations.
It can be used in tf.distribute.ReplicaContext.merge_call to write code that works for all tf.distribute.Strategy.
```python
@tf.function
def step_fn(var):

  def merge_fn(strategy, value, var):
    # All-reduce the value. Note that `value` here is a
    # `tf.distribute.DistributedValues`.
    reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM,
                                          value, destinations=var)
    strategy.extended.update(var, lambda var, value: var.assign(value),
                             args=(reduced,))

  value = tf.identity(1.)
  tf.distribute.get_replica_context().merge_call(merge_fn,
                                                 args=(value, var))

def run(strategy):
  with strategy.scope():
    v = tf.Variable(0.)
    strategy.run(step_fn, args=(v,))
    return v

run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
# MirroredVariable:{
#   0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
#   1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
# }
run(tf.distribute.experimental.CentralStorageStrategy(
    compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
# <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
run(tf.distribute.OneDeviceStrategy("GPU:0"))
# <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
```
Args | |
---|---|
reduce_op | a tf.distribute.ReduceOp value specifying how values should be combined. Allows using string representation of the enum such as "SUM", "MEAN". |
value | a tf.distribute.DistributedValues , or a tf.Tensor like object. |
destinations | a tf.distribute.DistributedValues , a tf.Variable , a tf.Tensor alike object, or a device string. It specifies the devices to reduce to. To perform an all-reduce, pass the same to value and destinations . Note that if it's a tf.Variable , the value is reduced to the devices of that variable, and this method doesn't update the variable. |
options | a tf.distribute.experimental.CommunicationOptions . Options to perform collective operations. This overrides the default options if the tf.distribute.Strategy takes one in the constructor. See tf.distribute.experimental.CommunicationOptions for details of the options. |
Returns | |
---|---|
A tensor or value reduced to destinations . |
update
update( var, fn, args=(), kwargs=None, group=True )
Run fn to update var using inputs mirrored to the same devices.

tf.distribute.StrategyExtended.update takes a distributed variable var to be updated, an update function fn, and args and kwargs for fn. It applies fn to each component variable of var and passes corresponding values from args and kwargs. Neither args nor kwargs may contain per-replica values. If they contain mirrored values, they will be unwrapped before calling fn. For example, fn can be assign_add and args can be a mirrored DistributedValues where each component contains the value to be added to this mirrored variable var. Calling update will call assign_add on each component variable of var with the corresponding tensor value on that device.
```python
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 devices
with strategy.scope():
  v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)

def update_fn(v):
  return v.assign(1.0)

result = strategy.extended.update(v, update_fn)
# result is
# Mirrored:{
#   0: tf.Tensor(1.0, shape=(), dtype=float32),
#   1: tf.Tensor(1.0, shape=(), dtype=float32)
# }
```
If var is mirrored across multiple devices, then this method implements logic as follows:
```python
results = {}
for device, v in var:
  with tf.device(device):
    # args and kwargs will be unwrapped if they are mirrored.
    results[device] = fn(v, *args, **kwargs)
return merged(results)
```
Otherwise, this method returns fn(var, *args, **kwargs) colocated with var.
Args | |
---|---|
var | Variable, possibly mirrored to multiple devices, to operate on. |
fn | Function to call. Should take the variable as the first argument. |
args | Tuple or list. Additional positional arguments to pass to fn() . |
kwargs | Dict with keyword arguments to pass to fn() . |
group | Boolean. Defaults to True. If False, the return value will be unwrapped. |
Returns | |
---|---|
By default, the merged return value of fn across all replicas. The merged result has dependencies to make sure that if it is evaluated at all, the side effects (updates) will happen on every replica. If instead "group=False" is specified, this function will return a nest of lists where each list has an element per replica, and the caller is responsible for ensuring all elements are executed. |
update_non_slot
update_non_slot( colocate_with, fn, args=(), kwargs=None, group=True )
Runs fn(*args, **kwargs) on colocate_with devices.
Used to update non-slot variables.
DEPRECATED: TF 1.x ONLY.
Args | |
---|---|
colocate_with | Devices returned by non_slot_devices() . |
fn | Function to execute. |
args | Tuple or list. Positional arguments to pass to fn() . |
kwargs | Dict with keyword arguments to pass to fn() . |
group | Boolean. Defaults to True. If False, the return value will be unwrapped. |
Returns | |
---|---|
Return value of fn , possibly merged across devices. |
value_container
value_container( value )
Returns the container that this per-replica value belongs to.
Args | |
---|---|
value | A value returned by run() or a variable created in scope() . |
Returns | |
---|---|
A container that value belongs to. If value does not belong to any container (including the case of container having been destroyed), returns the value itself. value in experimental_local_results(value_container(value)) will always be true. |
variable_created_in_scope
variable_created_in_scope( v )
Tests whether v was created while this strategy scope was active.
Variables created inside the strategy scope are "owned" by it:
```python
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.)
strategy.extended.variable_created_in_scope(v)
# True
```
Variables created outside the strategy are not owned by it:
```python
strategy = tf.distribute.MirroredStrategy()
v = tf.Variable(1.)
strategy.extended.variable_created_in_scope(v)
# False
```
Args | |
---|---|
v | A tf.Variable instance. |
Returns | |
---|---|
True if v was created inside the scope, False if not. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/compat/v1/distribute/StrategyExtended