DistributionStrategy
Defined in tensorflow/python/training/distribute.py.
A list of devices with a state & compute distribution policy.
The intent is that you can write an algorithm in a stylized way and it will be usable with a variety of different DistributionStrategy
implementations. Each descendant will implement a different strategy for distributing the algorithm across multiple devices/machines. Furthermore, these changes can be hidden inside the specific layers and other library classes that need special treatment to run in a distributed setting, so that most users' model definition code can run unchanged. The DistributionStrategy
API works the same way with eager and graph execution.
First let's introduce a few high-level concepts:

* Data parallelism is where we run multiple copies of the model on different slices of the input data (in contrast to model parallelism, where a single copy of the model is divided across multiple devices).
* A tower is one copy of the model, running on one slice of the input data.
* Synchronous (sync) training is where the updates from each tower are aggregated before updating the model variables, in contrast to asynchronous (async) training, where each tower updates the variables independently.

To distribute an algorithm, we might use some of these ingredients:

* Parameter servers: hosts that hold a single copy of the variables; towers read them at the start of a step and send updates at the end.
* Mirrored variables: variables copied to multiple devices, kept in sync by applying the same updates to every copy.
* Reductions and all-reduce: a reduction aggregates multiple values into one (e.g. "sum" or "mean"); all-reduce performs a reduction on values from multiple devices and makes the result available on all of those devices.
We have then a few approaches we want to support:

* Code written (as if) with no knowledge of class DistributionStrategy. This code should work as before, even if some of the layers, etc. used by that code are written to be distribution-aware. This is done by having a default DistributionStrategy that gives ordinary behavior, and by default being in a single tower context.
* Ordinary model code that you want to run using a specific DistributionStrategy. This can be as simple as:

```python
with my_distribution.scope():
  iterator = my_distribution.distribute_dataset(dataset)
  tower_train_ops = my_distribution.call_for_each_tower(
      tower_fn, iterator.get_next())
  train_op = tf.group(my_distribution.unwrap(tower_train_ops))
```
This takes an ordinary dataset and tower_fn and runs it distributed using a particular DistributionStrategy in my_distribution. Any variables created in tower_fn are created using my_distribution's policy, and library functions called by tower_fn can use the get_tower_context() API to get enhanced behavior in this case. (A sketch of what such a tower_fn might look like follows this list.)

Note that in the future we will add support for initializable Dataset iterators, at which point this example code will change.
* If you want to write a distributed algorithm yourself, you may use any of the DistributionStrategy APIs inside a with my_distribution.scope(): block of code.
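As promised above, here is a minimal sketch of what such a tower_fn might look like, end to end. It is only an illustration: the MirroredStrategy constructor arguments, the toy dense layer, and the synthetic dataset are assumptions made for the example, not part of this API.

```python
import tensorflow as tf

def tower_fn(features):
  # Runs once per tower, in a tower context; the layer's variables are
  # created according to the enclosing DistributionStrategy's policy.
  predictions = tf.layers.dense(features, 1)
  loss = tf.reduce_mean(tf.square(predictions))
  optimizer = tf.train.GradientDescentOptimizer(0.1)
  # In a tower context, Optimizer.apply_gradients() (called by minimize())
  # handles summing the gradients across towers and updating the variables.
  return optimizer.minimize(loss)

# Constructor arguments are an assumption; any concrete DistributionStrategy
# implementation would do here.
my_distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
dataset = tf.data.Dataset.from_tensors(tf.zeros([16, 5])).repeat()

with my_distribution.scope():
  iterator = my_distribution.distribute_dataset(dataset)
  tower_train_ops = my_distribution.call_for_each_tower(
      tower_fn, iterator.get_next())
  train_op = tf.group(my_distribution.unwrap(tower_train_ops))
```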
Lower-level concepts:

* Wrapped values: in order to represent values that are parallel across devices, we wrap them in an object that contains a map from device to value.
* Unwrapping and merging: consider calling a function fn on multiple devices, like call_for_each_tower(fn, w) with an argument w that is a wrapped value. This means w will have a map taking tower device d0 to w0, tower device d1 to w1, etc. call_for_each_tower() unwraps w before calling fn, so it calls fn(w0) on d0, fn(w1) on d1, etc. It then merges the return values from fn(), which can possibly result in wrapped values. For example, let's say fn() returns a tuple with three components: (x, a, v0) from tower 0, (x, b, v1) on tower 1, etc. If the first component is the same object x from every tower, then the first component of the merged result will also be x. If the second component is different (a, b, ...) from each tower, then the merged value will have a wrapped map from tower device to the different values. If the third component is the members of a mirrored variable (v maps d0 to v0, d1 to v1, etc.), then the merged result will be that mirrored variable (v).
* Tower context vs. cross-tower context: a tower context is entered when running a function once per tower (for example, the fn passed to call_for_each_tower()); in cross-tower context you can call the DistributionStrategy methods which operate across the towers (like reduce()). By default you start in a tower context (the default "single tower context") and then some methods can switch you back and forth, as described below.
* Non-slot devices: all "non-slot" variables need to be allocated on a consistent set of devices. If you have a variable you want the other non-slot variables colocated with, you can use colocate_vars_with() to get the remaining non-slot variables on the same device. Otherwise you can use non_slot_devices() to pick a consistent set of devices to pass to both colocate_vars_with() and update_non_slot().
When using a DistributionStrategy, we have a new type dimension called locality that says what values are compatible with which APIs:

* T: a different value for each tower (a per-device wrapped value).
* M: value is "mirrored" across all the worker devices, i.e. there are copies with the same value on each worker device (a Mirrored-wrapped value).
* V(v): value is "mirrored" across all the devices which have a copy of variable v (also a Mirrored-wrapped value, but over parameter devices instead of worker devices).
* N: value is "mirrored" across all the "non-slot" devices.

Rules for methods with respect to locality and single-tower vs. cross-tower context:
* with d.scope(): default single-tower context -> cross-tower context for d
* with d.colocate_vars_with(v): in tower/cross-tower context, variables will be created with locality V(v). That is, if we write with d.colocate_vars_with(v1): v2 = tf.get_variable(...), then v2 will have locality V(v1), i.e. locality V(v2) will equal V(v1).
* with d.colocate_vars_with(d.non_slot_devices(...)): in tower/cross-tower context, variables will be created with locality N
* v = tf.get_variable(...): in tower/cross-tower context, creates a variable (which by definition will have locality V(v), though will match another locality if inside a colocate_vars_with scope).
* d.distribute_dataset(dataset): in cross-tower context, produces an iterator with locality T
* d.broadcast(t): in cross-tower context, produces a value with locality M
* d.broadcast(t, v): in cross-tower context, produces a value with locality V(v)
* d.call_for_each_tower(fn, ...): in cross-tower context, runs fn() in a tower context (and so may call get_tower_context() and use its API, including merge_call() to get back to cross-tower context), once for each tower. May use values with locality T or M, and any variable.
* d.reduce(m, t): in cross-tower context, accepts t with locality T and produces a value with locality M.
* d.reduce(m, t, v): in cross-tower context, accepts t with locality T and produces a value with locality V(v).
* d.batch_reduce(m, [(t, v)]): see d.reduce()
* d.update(v, fn, ...): in cross-tower context, runs fn() once for each device v is copied to; all inputs should have locality V(v), and the output will have locality V(v) as well.
* d.update_non_slot(d.non_slot_devices(), fn): in cross-tower context, like d.update() except with locality N.
* d.fetch(t): copy t with any locality to the client's CPU device.

The standard pattern for updating variables is to:
1. Wrap your input dataset in d.distribute_dataset().
2. Define the computation for each tower with d.call_for_each_tower(), up to the point of getting a list of gradient, variable pairs.
3. Call d.reduce("sum", t, v) or d.batch_reduce() to sum the gradients (with locality T) into values with locality V(v).
4. Call d.update(v) for each variable to update its value.

Steps 3 and 4 are done automatically by class Optimizer if you call its apply_gradients method in a tower context. Otherwise you can manually call its _distributed_apply method in a cross-tower context.
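For reference, here is a hedged sketch of doing steps 3 and 4 by hand in a cross-tower context. The plain SGD update, the learning rate, and the shape of grads_and_vars (assumed to be the merged result of call_for_each_tower(), i.e. a list of (per-tower gradient, mirrored variable) pairs) are assumptions made for the example, not the Optimizer's actual implementation.

```python
def manually_apply_gradients(d, grads_and_vars, learning_rate=0.1):
  # Call in cross-tower context, i.e. inside `with d.scope():` but outside
  # `d.call_for_each_tower()`.
  update_ops = []
  for grad, var in grads_and_vars:
    # Step 3: sum the per-tower gradients (locality T) into a value with
    # locality V(var).
    summed_grad = d.reduce("sum", grad, destinations=var)

    def apply_fn(v, g):
      # Plain SGD step, just for illustration.
      return v.assign_sub(learning_rate * g)

    # Step 4: run the assignment once on each device `var` is copied to.
    update_ops.append(d.update(var, apply_fn, summed_grad))
  return d.group(update_ops)
```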
Another thing you might want to do in the middle of your tower function is an all-reduce of some intermediate value, using d.reduce()
or d.batch_reduce()
without supplying a variable as the destination.
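For example, a tower function could normalize an intermediate tensor by its cross-tower mean. This is only a sketch; the normalization itself is an arbitrary choice for illustration.

```python
import tensorflow as tf

def tower_fn(features):
  # Runs in a tower context.
  tower_ctx = tf.get_tower_context()
  local_mean = tf.reduce_mean(features)

  def all_reduce(distribution, value):
    # Runs once in cross-tower context; `value` wraps the per-tower
    # `local_mean` tensors. With no destinations given, the mean is
    # mirrored back to the devices the value already lives on.
    return distribution.reduce("mean", value)

  global_mean = tower_ctx.merge_call(all_reduce, local_mean)
  return features - global_mean
```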
Layers should expect to be called in a tower context, and can use the get_tower_context()
function to get a TowerContext
object. The TowerContext
object has a merge_call()
method for entering cross-tower context where you can use reduce()
(or batch_reduce()
) and then optionally update()
to update state.
You may use this API whether or not a DistributionStrategy
is being used, since there is a default implementation of TowerContext
and DistributionStrategy
. Or you can use the get_tower_context().is_single_tower
property to run different code in the distributed vs. single tower cases.
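For instance, a library routine might skip the cross-tower communication entirely in the single-tower case; a small sketch:

```python
import tensorflow as tf

def maybe_average_across_towers(value):
  # Works both with and without a DistributionStrategy, thanks to the
  # default implementations of TowerContext and DistributionStrategy.
  tower_ctx = tf.get_tower_context()
  if tower_ctx.is_single_tower:
    # Single tower (e.g. the default strategy): nothing to combine.
    return value
  # Multiple towers: switch to cross-tower context and average.
  return tower_ctx.merge_call(
      lambda distribution, v: distribution.reduce("mean", v), value)
```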
is_single_tower
Returns whether there is a single tower or multiple.
A boolean. If True, call_for_each_tower(fn) will only call fn once. If False, call_for_each_tower(fn) may call fn multiple times.
num_towers
Returns number of towers, for purposes of averaging across towers.
parameter_devices
Returns the list of devices used for variable and update
placement.
worker_device_index
An object mapping worker device to an id.
This might be passed as an argument to call_for_each_tower(), as in:

```python
with distribution_strategy.scope():
  def fn(device_id):
    # `device_id` is an integer. `fn` is being executed on device:
    #   distribution_strategy.worker_devices[device_id].
    pass

  distribution_strategy.call_for_each_tower(
      fn, distribution_strategy.worker_device_index)
```

Returns an index object, or the integer 0 if there is only a single tower.
worker_devices
Returns the list of devices used to run call_for_each_tower()
calls.
batch_reduce
batch_reduce( method_string, value_destination_pairs )
Combine multiple reduce
calls into one for faster execution.
Args:

* method_string: A string indicating how to combine values, either "sum" or "mean".
* value_destination_pairs: A sequence of (value, destinations) pairs. See reduce() for a description.

Returns:

A list of mirrored values, one per pair in value_destination_pairs.
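A hedged sketch of using batch_reduce() to sum per-tower gradients in a single call; the shape of grads_and_vars is an assumption (a list of (per-tower gradient, mirrored variable) pairs obtained in cross-tower context).

```python
def batch_sum_gradients(distribution, grads_and_vars):
  # Must be called in cross-tower context. The i-th result is mirrored to
  # the devices of the i-th variable (locality V(v)).
  summed = distribution.batch_reduce(
      "sum", [(grad, var) for grad, var in grads_and_vars])
  return list(zip(summed, [var for _, var in grads_and_vars]))
```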
broadcast
broadcast( tensor, destinations=None )
Mirror a tensor on one device to all worker devices.
Args:

* tensor: A Tensor value to broadcast.
* destinations: An optional mirrored variable, device string, or list of device strings, specifying the destination devices to copy tensor to. Defaults to self.worker_devices.

Returns:

A value mirrored to destinations devices.
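A small sketch of both forms of broadcast(); the helper name and arguments are hypothetical, for illustration only.

```python
def broadcast_examples(distribution, tensor, some_var):
  # Must be called in cross-tower context.
  # Mirror `tensor` to every worker device (locality M).
  on_workers = distribution.broadcast(tensor)
  # Mirror `tensor` to the devices holding copies of `some_var`
  # (locality V(some_var)).
  on_var_devices = distribution.broadcast(tensor, destinations=some_var)
  return on_workers, on_var_devices
```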
call_for_each_tower
call_for_each_tower( fn, *args, **kwargs )
Run fn
once per tower.
fn
may call tf.get_tower_context()
to access methods such as tower_id()
and merge_call()
.
merge_call() is used to communicate between the towers and re-enter the cross-tower context. All towers pause their execution when they encounter a merge_call() call. After that the merge_fn function is executed. Its results are then unwrapped and given back to each tower call. After that, execution resumes until fn is complete or encounters another merge_call(). Example:
```python
# Called once in "cross-tower" context.
def merge_fn(distribution, three_plus_tower_id):
  # Sum the values across towers.
  return sum(distribution.unwrap(three_plus_tower_id))

# Called once per tower in `distribution`, in a "tower" context.
def fn(three):
  tower_ctx = tf.get_tower_context()
  v = three + tower_ctx.tower_id
  # Computes the sum of the `v` values across all towers.
  s = tower_ctx.merge_call(merge_fn, v)
  return s + v

with distribution.scope():
  # In "cross-tower" context
  ...
  merged_results = distribution.call_for_each_tower(fn, 3)
  # merged_results has the values from every tower execution of `fn`.
  print(distribution.unwrap(merged_results))  # Prints a list
```
Args:

* fn: function to run (will be run once per tower).
* *args: positional arguments for fn.
* **kwargs: keyword arguments for fn. "run_concurrently": Boolean indicating whether executions of fn can be run concurrently (under eager execution only), defaults to True.

Returns:

Merged return value of fn across all towers.
colocate_vars_with
colocate_vars_with(colocate_with_variable)
Scope that controls which devices variables will be created on.
No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation, others work by using a tf.colocate_with() scope).

This may only be used inside self.scope().
Example usage:
```python
with distribution_strategy.scope():
  var1 = tf.get_variable(...)
  with distribution_strategy.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.get_variable(...)
    var3 = tf.get_variable(...)

  def fn(v1, v2, v3):
    # Operates on v1 from var1, v2 from var2, and v3 from var3.
    pass

  # `fn` runs on every device `var1` is on; `var2` and `var3` will be
  # there too.
  distribution_strategy.update(var1, fn, var2, var3)
```
Args:

* colocate_with_variable: A variable created in self.scope(). Variables created while in the returned context manager will be on the same set of devices as colocate_with_variable.

Returns:

A context manager.
configure
configure(session_config=None)
Find the best configuration given a TensorFlow session config.
distribute_dataset
distribute_dataset(dataset)
Return an iterator into dataset
split across all towers.
Suitable for providing input to call_for_each_tower(), as in:
```python
with distribution_strategy.scope():
  iterator = distribution_strategy.distribute_dataset(dataset)
  tower_results = distribution_strategy.call_for_each_tower(
      tower_fn, iterator.get_next())
```
Args:

* dataset: A tf.data.Dataset.

Returns:

A Dataset iterator that will produce separate splits for each tower.
fetch
fetch( val, destination='/device:CPU:0', fn=(lambda x: x) )
Return a copy of val
or fn(val)
on destination
.
This is useful for getting a mirrored value onto a device. It will attempt to avoid a copy by checking if the value is already on the destination device.
Args:

* val: Value (which may be mirrored) to copy.
* destination: A device string to copy the value to.
* fn: An optional function to apply to the value on the source device, before copying.

Returns:

A Tensor on destination.
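A sketch of the typical use: pulling a (possibly mirrored) value back to the client, with an optional transformation applied on the source device first. The scaling function here is an arbitrary example.

```python
def fetch_to_client(distribution, mirrored_value):
  # Copy `mirrored_value` to the client's CPU. `fn` runs on the source
  # device before the copy; here it just scales the value.
  return distribution.fetch(
      mirrored_value, destination="/device:CPU:0",
      fn=lambda x: x / distribution.num_towers)
```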
group
group( value, name=None )
Shortcut for tf.group(distribution.unwrap(value))
.
non_slot_devices
non_slot_devices(var_list)
Device(s) for non-slot variables.
Create variables on these devices in a with colocate_vars_with(non_slot_devices(...)):
block. Update those using update_non_slot()
.
Args:

* var_list: The list of variables being optimized, needed with the default DistributionStrategy.
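A hedged sketch of the intended pattern, using a hypothetical step counter as the non-slot variable; the variable name and the update are illustrations, not part of this API.

```python
import tensorflow as tf

def create_and_bump_step(distribution, var_list):
  # In cross-tower context, inside `with distribution.scope():`.
  non_slot = distribution.non_slot_devices(var_list)
  with distribution.colocate_vars_with(non_slot):
    # `step` gets locality N: it lives on the consistent set of non-slot
    # devices picked above.
    step = tf.get_variable("step", shape=[], dtype=tf.int64,
                           initializer=tf.zeros_initializer(),
                           trainable=False)

  def bump():
    # Runs on the non-slot devices; closes over `step` to keep the
    # sketch simple.
    return step.assign_add(1)

  return distribution.update_non_slot(non_slot, bump)
```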
reduce

reduce( method_string, value, destinations=None )
Combine (via e.g. sum or mean) values across towers.
Args:

* method_string: A string indicating how to combine values, either "sum" or "mean".
* value: A per-device value with one value per tower.
* destinations: An optional mirrored variable, a device string, or a list of device strings. The return value will be copied to all destination devices (or all the devices where the mirrored variable resides). If None or unspecified, the destinations will match the devices value resides on.

Returns:

A value mirrored to destinations.
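A small sketch of averaging a per-tower loss for logging; tower_fn and iterator are assumed to come from the surrounding training code.

```python
def mean_tower_loss(distribution, tower_fn, iterator):
  # In cross-tower context, inside `with distribution.scope():`.
  per_tower_loss = distribution.call_for_each_tower(
      tower_fn, iterator.get_next())
  # The per-tower losses have locality T; with no destinations given, the
  # mean is mirrored back to the same devices.
  return distribution.reduce("mean", per_tower_loss)
```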
scope
scope()
Returns a context manager selecting this DistributionStrategy as current.
Inside a with distribution_strategy.scope():
code block, this thread will use a variable creator set by distribution_strategy
, and will enter its "cross-tower context".
Returns:

A context manager.
tower_local_var_scope
tower_local_var_scope(reduce_method)
Inside this scope, new variables will not be mirrored.
There will still be one component variable per tower, but there is no requirement that they stay in sync. Instead, when saving them or calling fetch(), we use the value that results when calling reduce() on all the towers' variables.
Note: tower-local implies not trainable. Instead, it is expected that each tower will directly update (using assign_add() or whatever) its local variable instance, but only the aggregated value (accessible using fetch()) will be exported from the model. When it is acceptable to only aggregate on export, we greatly reduce communication overhead by using tower-local variables.
Note: All component variables will be initialized to the same value, using the initialization expression from the first tower. The values will match even if the initialization expression uses random numbers.
Args:

* reduce_method: String used as a method_string to reduce() to get the value to save when checkpointing.

Returns:

A context manager.
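A hedged sketch of a tower-local counter accumulated with the "sum" reduction; the metric itself and the helper names are illustrative only.

```python
import tensorflow as tf

def make_tower_fn(distribution):
  def tower_fn(values):
    # Each tower gets its own, unsynchronized copy of `total`; when
    # checkpointing or calling fetch(), the "sum" across towers is used.
    with distribution.tower_local_var_scope("sum"):
      total = tf.get_variable("total", shape=[],
                              initializer=tf.zeros_initializer(),
                              trainable=False)
    return total.assign_add(tf.reduce_sum(values))
  return tower_fn

# Usage sketch (in cross-tower context):
#   with distribution.scope():
#     update_op = distribution.call_for_each_tower(
#         make_tower_fn(distribution), iterator.get_next())
```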
unwrap
unwrap(value)
Returns the list of all per-device values contained in value
.
Args:

* value: A value returned by call_for_each_tower() or a variable created in scope().

Returns:

A list of values contained in value. If value represents a single value, this returns [value].
update
update( var, fn, *args, **kwargs )
Run fn
to update var
using inputs mirrored to the same devices.
If var
is mirrored across multiple devices, then this implements logic like:
```python
results = {}
for device, v in var:
  with tf.device(device):
    # *args and **kwargs will be unwrapped if they are mirrored.
    results[device] = fn(v, *args, **kwargs)
return merged(results)
```
Otherwise this returns fn(var, *args, **kwargs) colocated with var.

Neither *args nor **kwargs may contain per-device values. If they contain mirrored values, they will be unwrapped before calling fn.
Args:

* var: Variable, possibly mirrored to multiple devices, to operate on.
* fn: Function to call. Should take the variable as the first argument.
* *args: Additional positional arguments to pass to fn().
* **kwargs: Keyword arguments to pass to fn().

Returns:

Merged return value of fn across all towers.
update_non_slot
update_non_slot( colocate_with, fn, *args, **kwargs )
Runs fn(*args, **kwargs)
on colocate_with
devices.
Args:

* colocate_with: The return value of non_slot_devices().
* fn: Function to execute.
* *args: Positional arguments to pass to fn().
* **kwargs: Keyword arguments to pass to fn().

Returns:

Return value of fn, possibly merged across devices.
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy