
tf.contrib.distribute.MirroredStrategy

Class MirroredStrategy

Inherits From: DistributionStrategy

Defined in tensorflow/contrib/distribute/python/mirrored_strategy.py.

Mirrors variables to distribute work across multiple devices on a single machine.

This strategy uses one tower per device and synchronous replication.

Properties

is_single_tower

Returns whether there is a single tower or multiple.

Returns:

A boolean. If True, call_for_each_tower(fn) will only call fn once. If False, call_for_each_tower(fn) may call fn multiple times.

num_towers

Returns the number of towers, for purposes of averaging across towers.

parameter_devices

Returns the list of devices used for variable and update placement.

worker_device_index

An object mapping worker device to an id.

This might be passed as an argument to call_for_each_tower(), as in:

with distribution_strategy.scope():

  def fn(device_id):
    # device_id is an integer. `fn` is being executed on device:
    #    distribution_strategy.worker_devices[device_id].
    pass

  distribution_strategy.call_for_each_tower(
      fn, distribution_strategy.worker_device_index)

Returns:

An index object, or the integer 0 if there is only a single tower.

worker_devices

Returns the list of devices used to run call_for_each_tower() calls.

Methods

__init__

__init__(
    devices=None,
    num_gpus=None,
    cross_tower_ops=None,
    prefetch_on_device=None
)

Initialize self. See help(type(self)) for accurate signature.

batch_reduce

batch_reduce(
    method_string,
    value_destination_pairs
)

Combine multiple reduce calls into one for faster execution.

Args:

  • method_string: A string indicating how to combine values, either "sum" or "mean".
  • value_destination_pairs: A sequence of (value, destinations) pairs. See reduce() for a description.

Returns:

A list of mirrored values, one per pair in value_destination_pairs.
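To make the contract concrete, here is a minimal plain-Python model of batch_reduce(). It is a sketch under stated assumptions, not the TensorFlow implementation. A per-device value is represented as a hypothetical dict mapping a device string to that tower's number, and reduce_one is a made-up helper. The model returns the same list as calling reduce once per pair. A real implementation can fuse the pairs into fewer communication steps, which is where the speedup comes from.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical stand-in (not the TensorFlow API): a per-device value is a
# dict mapping a device string to that tower's value.
PerDevice = Dict[str, float]


def reduce_one(method_string: str, value: PerDevice,
               destinations: Sequence[str]) -> PerDevice:
    """Sum or average one value per tower, then mirror it to destinations."""
    total = sum(value.values())
    combined = total if method_string == "sum" else total / len(value)
    return {device: combined for device in destinations}


def batch_reduce(method_string: str,
                 value_destination_pairs: Sequence[Tuple[PerDevice, Sequence[str]]]
                 ) -> List[PerDevice]:
    """One mirrored result per (value, destinations) pair, in input order."""
    if method_string not in ("sum", "mean"):
        raise ValueError("method_string must be 'sum' or 'mean'")
    return [reduce_one(method_string, value, dests)
            for value, dests in value_destination_pairs]


grads = {"/device:GPU:0": 2.0, "/device:GPU:1": 4.0}
losses = {"/device:GPU:0": 0.5, "/device:GPU:1": 1.5}
devices = ["/device:GPU:0", "/device:GPU:1"]
results = batch_reduce("mean", [(grads, devices), (losses, devices)])
print(results)
# [{'/device:GPU:0': 3.0, '/device:GPU:1': 3.0}, {'/device:GPU:0': 1.0, '/device:GPU:1': 1.0}]
```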

broadcast

broadcast(
    tensor,
    destinations=None
)

Mirror a tensor on one device to all worker devices.

Args:

  • tensor: A Tensor value to broadcast.
  • destinations: An optional mirrored variable, device string, or list of device strings, specifying the destination devices to copy tensor to. Defaults to self.worker_devices.

Returns:

A value mirrored to destinations devices.
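In a plain-Python model, broadcasting produces a mirrored value: the same value placed on every destination device. This is a sketch, not the TensorFlow API. WORKER_DEVICES is a hypothetical stand-in for self.worker_devices, and the dict {device: value} is an assumed representation of a mirrored value. Passing a mirrored variable as the destination is not modeled.

```python
from typing import Dict, List, Optional, Union

# Hypothetical stand-in for self.worker_devices (illustration only).
WORKER_DEVICES = ["/device:GPU:0", "/device:GPU:1"]


def broadcast(tensor: float,
              destinations: Optional[Union[str, List[str]]] = None) -> Dict[str, float]:
    """Copy one value to every destination device; the result is mirrored."""
    if destinations is None:
        destinations = WORKER_DEVICES  # default: all worker devices
    elif isinstance(destinations, str):
        destinations = [destinations]  # a single device string
    return {device: tensor for device in destinations}


mirrored = broadcast(0.25)
print(mirrored)  # {'/device:GPU:0': 0.25, '/device:GPU:1': 0.25}
```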

call_for_each_tower

call_for_each_tower(
    fn,
    *args,
    **kwargs
)

Run fn once per tower.

fn may call tf.get_tower_context() to access members such as the tower_id property and the merge_call() method.

merge_call() is used to communicate between the towers and to re-enter the cross-tower context. When a tower encounters a merge_call() call, it pauses until all towers have reached it. merge_fn is then executed once, in the cross-tower context. Its results are unwrapped and returned to each tower's merge_call() call, and each tower resumes execution until fn completes or encounters another merge_call(). Example:

# Called once in "cross-tower" context.
def merge_fn(distribution, three_plus_tower_id):
  # sum the values across towers
  return sum(distribution.unwrap(three_plus_tower_id))

# Called once per tower in `distribution`, in a "tower" context.
def fn(three):
  tower_ctx = tf.get_tower_context()
  v = three + tower_ctx.tower_id
  # Computes the sum of the `v` values across all towers.
  s = tower_ctx.merge_call(merge_fn, v)
  return s + v

with distribution.scope():
  # in "cross-tower" context
  ...
  merged_results = distribution.call_for_each_tower(fn, 3)
  # merged_results has the values from every tower execution of `fn`.
  print(distribution.unwrap(merged_results))  # Prints a list

Args:

  • fn: function to run (will be run once per tower).
  • *args: positional arguments for fn
  • **kwargs: keyword arguments for fn. The special keyword argument "run_concurrently" is a Boolean indicating whether executions of fn can run concurrently (under eager execution only); it defaults to True.

Returns:

Merged return value of fn across all towers.
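The pause-and-merge protocol of merge_call() can be modeled in plain Python with one thread per tower and a threading.Barrier whose action runs merge_fn exactly once. This is a hypothetical sketch, not how TensorFlow implements call_for_each_tower(). ToyStrategy and ToyTowerContext are made-up names. Here fn receives its tower context as an argument instead of calling tf.get_tower_context(), and merge_fn receives a plain list of per-tower values rather than (distribution, per-device value).

```python
import threading
from typing import Any, Callable, List

# Hypothetical plain-Python model of merge_call(); not TensorFlow APIs.


class ToyTowerContext:
    def __init__(self, tower_id: int, shared: "ToyStrategy") -> None:
        self.tower_id = tower_id
        self._shared = shared

    def merge_call(self, merge_fn: Callable[[List[Any]], Any], value: Any) -> Any:
        s = self._shared
        s.inputs[self.tower_id] = value  # publish this tower's value
        s.merge_fn = merge_fn
        s.barrier.wait()                 # pause until every tower arrives
        return s.merge_result            # every tower gets the same result


class ToyStrategy:
    def __init__(self, num_towers: int) -> None:
        self.num_towers = num_towers
        self.inputs: List[Any] = [None] * num_towers
        self.merge_fn: Callable[[List[Any]], Any] = lambda values: None
        self.merge_result: Any = None
        # The barrier action runs once, in one thread, after all towers have
        # called wait() and before any of them is released.
        self.barrier = threading.Barrier(num_towers, action=self._run_merge)

    def _run_merge(self) -> None:
        self.merge_result = self.merge_fn(list(self.inputs))

    def call_for_each_tower(self, fn: Callable[..., Any], *args: Any) -> List[Any]:
        results: List[Any] = [None] * self.num_towers

        def run(tower_id: int) -> None:
            results[tower_id] = fn(ToyTowerContext(tower_id, self), *args)

        threads = [threading.Thread(target=run, args=(i,))
                   for i in range(self.num_towers)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        return results  # one result per tower, like unwrap(merged_results)


def merge_fn(values: List[int]) -> int:
    return sum(values)  # runs once, in the "cross-tower" step


def fn(tower_ctx: ToyTowerContext, three: int) -> int:
    v = three + tower_ctx.tower_id
    s = tower_ctx.merge_call(merge_fn, v)
    return s + v


print(ToyStrategy(num_towers=2).call_for_each_tower(fn, 3))  # [10, 11]
```

With two towers, the towers compute v = 3 and v = 4, merge_fn sums them to 7 once, and each tower returns 7 + v.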

colocate_vars_with

colocate_vars_with(colocate_with_variable)

Scope that controls which devices variables will be created on.

No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation, others work by using a tf.colocate_with() scope).

This may only be used inside self.scope().

Example usage:

with distribution_strategy.scope():
  var1 = tf.get_variable(...)
  with distribution_strategy.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.get_variable(...)
    var3 = tf.get_variable(...)

  def fn(v1, v2, v3):
    # operates on v1 from var1, v2 from var2, and v3 from var3
    pass

  # `fn` runs on every device `var1` is on; `var2` and `var3` will be there too.
  distribution_strategy.update(var1, fn, var2, var3)

Args:

  • colocate_with_variable: A variable created in self.scope(). Variables created while in the returned context manager will be on the same set of devices as colocate_with_variable.

Returns:

A context manager.

configure

configure(session_config=None)

Finds the best configuration given a TensorFlow session config.

distribute_dataset

distribute_dataset(dataset)

Return an iterator into dataset split across all towers.

Suitable for providing input to call_for_each_tower(), as in:

with distribution_strategy.scope():
  iterator = distribution_strategy.distribute_dataset(dataset)
  tower_results = distribution_strategy.call_for_each_tower(
      tower_fn, iterator.get_next())

Args:

  • dataset: The tf.data.Dataset to split across towers.

Returns:

A Dataset iterator that will produce separate splits for each tower.
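The idea of "separate splits for each tower" can be modeled as an iterator whose get_next() pulls one element (for example, one batch) per tower, so towers see disjoint data. This is a hypothetical plain-Python sketch. ToyPerTowerIterator is a made-up name, and the assumption that each tower receives the next dataset element in device order is illustrative rather than a statement of the exact TensorFlow behavior.

```python
from typing import Dict, Iterable, Iterator, List

# Hypothetical model (not TensorFlow): each get_next() call takes one
# element per tower from the underlying dataset, in device order.


class ToyPerTowerIterator:
    def __init__(self, dataset: Iterable[List[int]], devices: List[str]) -> None:
        self._it: Iterator[List[int]] = iter(dataset)
        self._devices = devices

    def get_next(self) -> Dict[str, List[int]]:
        # Raises StopIteration once the dataset is exhausted.
        return {device: next(self._it) for device in self._devices}


batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
iterator = ToyPerTowerIterator(batches, ["/device:GPU:0", "/device:GPU:1"])
print(iterator.get_next())  # {'/device:GPU:0': [0, 1], '/device:GPU:1': [2, 3]}
print(iterator.get_next())  # {'/device:GPU:0': [4, 5], '/device:GPU:1': [6, 7]}
```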

fetch

fetch(
    val,
    destination='/device:CPU:0',
    fn=(lambda x: x)
)

Return a copy of val or fn(val) on destination.

This is useful for getting a mirrored value onto a device. It will attempt to avoid a copy by checking if the value is already on the destination device.

Args:

  • val: Value (which may be mirrored) to copy.
  • destination: A device string to copy the value to.
  • fn: An optional function to apply to the value on the source device, before copying.

Returns:

A Tensor on destination.

group

group(
    value,
    name=None
)

Shortcut for tf.group(distribution.unwrap(value)).

map

map(
    map_over,
    fn,
    *args,
    **kwargs
)

non_slot_devices

non_slot_devices(var_list)

Device(s) for non-slot variables.

Create variables on these devices inside a with colocate_vars_with(non_slot_devices(...)): block, and update them using update_non_slot().

Args:

  • var_list: The list of variables being optimized, needed with the default DistributionStrategy.

reduce

reduce(
    method_string,
    value,
    destinations=None
)

Combine (via e.g. sum or mean) values across towers.

Args:

  • method_string: A string indicating how to combine values, either "sum" or "mean".
  • value: A per-device value with one value per tower.
  • destinations: An optional mirrored variable, device string, or list of device strings. The return value will be copied to all destination devices (or all the devices where the mirrored variable resides). If None or unspecified, the destinations will match the devices that value resides on.

Returns:

A value mirrored to destinations.
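The reduce semantics (combine one value per tower, then mirror the result) can be sketched in plain Python. This is a sketch under stated assumptions, not the TensorFlow API. A per-device value is a hypothetical dict {device: value}, and a mirrored variable as the destination is not modeled. Passing no destinations mirrors the result back onto the devices the input lives on.

```python
from typing import Dict, Optional, Sequence, Union

# Hypothetical stand-in (not the TensorFlow API): a per-device value is a
# dict mapping a device string to that tower's value.
PerDevice = Dict[str, float]


def reduce(method_string: str, value: PerDevice,
           destinations: Optional[Union[str, Sequence[str]]] = None) -> PerDevice:
    """Combine one value per tower, then mirror the result to destinations."""
    if method_string not in ("sum", "mean"):
        raise ValueError("method_string must be 'sum' or 'mean'")
    total = sum(value.values())
    combined = total if method_string == "sum" else total / len(value)
    # With no destinations, mirror onto the devices `value` resides on.
    if destinations is None:
        destinations = list(value)
    elif isinstance(destinations, str):
        destinations = [destinations]
    return {device: combined for device in destinations}


per_tower = {"/device:GPU:0": 1.0, "/device:GPU:1": 3.0}
print(reduce("sum", per_tower))                    # {'/device:GPU:0': 4.0, '/device:GPU:1': 4.0}
print(reduce("mean", per_tower, "/device:CPU:0"))  # {'/device:CPU:0': 2.0}
```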

scope

scope()

Returns a context manager selecting this DistributionStrategy as current.

Inside a with distribution_strategy.scope(): code block, this thread will use a variable creator set by distribution_strategy, and will enter its "cross-tower context".

Returns:

A context manager.

tower_local_var_scope

tower_local_var_scope(reduce_method)

Inside this scope, new variables will not be mirrored.

There will still be one component variable per tower, but there is no requirement that they stay in sync. Instead, when saving them or calling fetch(), we use the value that results when calling reduce() on all the towers' variables.

Note: tower-local implies not trainable. Instead, each tower is expected to update its local variable instance directly (using assign_add() or a similar method), but only the aggregated value (accessible using fetch()) will be exported from the model. When it is acceptable to aggregate only on export, using tower-local variables greatly reduces communication overhead.
Note: All component variables will be initialized to the same value, using the initialization expression from the first tower. The values will match even if the initialization expression uses random numbers.

Args:

  • reduce_method: String used as a method_string to reduce() to get the value to save when checkpointing.

Returns:

A context manager.
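A tower-local variable can be modeled in plain Python as one component per tower that starts from a shared initial value, is updated locally, and is reduced only when fetched. This is a hypothetical sketch. ToyTowerLocalVariable and its methods are made-up names, and passing the device explicitly to assign_add() is a simplification of each tower updating its own instance.

```python
from typing import Dict, List

# Hypothetical model of a tower-local variable (not the TensorFlow API).


class ToyTowerLocalVariable:
    def __init__(self, devices: List[str], initial_value: float,
                 reduce_method: str) -> None:
        if reduce_method not in ("sum", "mean"):
            raise ValueError("reduce_method must be 'sum' or 'mean'")
        # Every component starts from the same initial value.
        self.components: Dict[str, float] = {d: initial_value for d in devices}
        self.reduce_method = reduce_method

    def assign_add(self, device: str, delta: float) -> None:
        # Local update only: no communication with other towers.
        self.components[device] += delta

    def fetch(self) -> float:
        # Aggregation happens only here (e.g. when saving or fetching).
        total = sum(self.components.values())
        if self.reduce_method == "sum":
            return total
        return total / len(self.components)


counter = ToyTowerLocalVariable(["/device:GPU:0", "/device:GPU:1"], 0.0, "sum")
counter.assign_add("/device:GPU:0", 3.0)  # e.g. examples seen by tower 0
counter.assign_add("/device:GPU:1", 5.0)  # e.g. examples seen by tower 1
print(counter.components)  # components differ: they need not stay in sync
print(counter.fetch())     # 8.0
```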

unwrap

unwrap(value)

Returns the list of all per-device values contained in value.

Args:

  • value: A value returned by call_for_each_tower() or a variable created in scope().

Returns:

A list of values contained in value. If value represents a single value, this returns [value].

update

update(
    var,
    fn,
    *args,
    **kwargs
)

Run fn to update var using inputs mirrored to the same devices.

If var is mirrored across multiple devices, then this implements logic like:

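def update(var, fn, *args, **kwargs):
  results = {}
  # `var` is iterated as (device, component variable) pairs.
  for device, v in var:
    with tf.device(device):
      # *args and **kwargs will be unwrapped if they are mirrored.
      results[device] = fn(v, *args, **kwargs)
  # `merged` stands for combining the per-device results into one value.
  return merged(results)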

Otherwise, this returns fn(var, *args, **kwargs) colocated with var.

Neither *args nor **kwargs may contain per-device values. If they contain mirrored values, they will be unwrapped before calling fn.

Args:

  • var: Variable, possibly mirrored to multiple devices, to operate on.
  • fn: Function to call. Should take the variable as the first argument.
  • *args: Additional positional arguments to pass to fn().
  • **kwargs: Keyword arguments to pass to fn().

Returns:

Merged return value of fn across all towers.
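The per-device update loop can be made runnable in plain Python if a mirrored value is represented as a hypothetical dict {device: component}. This sketch is not the TensorFlow API. unwrap_for and apply_sgd are made-up helpers, the result is left as a per-device dict instead of being merged, and the rule that per-device (non-mirrored) values are rejected is not enforced, because a dict cannot tell the two apart.

```python
from typing import Any, Callable, Dict

# Hypothetical model: a mirrored value is a dict {device: component}.
Mirrored = Dict[str, Any]


def unwrap_for(value: Any, device: str) -> Any:
    """Pick the component on `device` if `value` is mirrored, else pass through."""
    return value[device] if isinstance(value, dict) else value


def update(var: Mirrored, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Mirrored:
    results = {}
    for device, v in var.items():
        # Mirrored args/kwargs are unwrapped to this device's component.
        dev_args = [unwrap_for(a, device) for a in args]
        dev_kwargs = {k: unwrap_for(x, device) for k, x in kwargs.items()}
        results[device] = fn(v, *dev_args, **dev_kwargs)
    return results  # one result per component of `var`


def apply_sgd(w: float, grad: float, lr: float) -> float:
    return w - lr * grad


weights = {"/device:GPU:0": 1.0, "/device:GPU:1": 1.0}
mirrored_grad = {"/device:GPU:0": 2.0, "/device:GPU:1": 2.0}
print(update(weights, apply_sgd, mirrored_grad, lr=0.25))
# {'/device:GPU:0': 0.5, '/device:GPU:1': 0.5}
```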

update_non_slot

update_non_slot(
    colocate_with,
    fn,
    *args,
    **kwargs
)

Runs fn(*args, **kwargs) on colocate_with devices.

Args:

  • colocate_with: The return value of non_slot_devices().
  • fn: Function to execute.
  • *args: Positional arguments to pass to fn().
  • **kwargs: Keyword arguments to pass to fn().

Returns:

Return value of fn, possibly merged across devices.

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/MirroredStrategy