TowerContext
Defined in tensorflow/python/training/distribute.py.
DistributionStrategy API inside a call_for_each_tower() call.
device
The device this tower is to be executed on, as a string.
distribution_strategy
The current DistributionStrategy object.
is_single_tower
Returns whether there is a single tower or multiple.
num_towers
Returns number of towers, for purposes of averaging across towers.
tower_id
Which tower is being defined, a number from 0 to num_towers - 1.
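For illustration, here is a minimal sketch of reading these properties inside a function run by call_for_each_tower(); get_tower_context(), MirroredStrategy, and the device strings are assumptions about this contrib API, not part of this page:

    import tensorflow as tf

    def model_fn():
      # Inside call_for_each_tower(), the current TowerContext is
      # available via get_tower_context() (assumed from this module).
      ctx = tf.contrib.distribute.get_tower_context()
      print("tower %d of %d on %s (single tower: %s)" %
            (ctx.tower_id, ctx.num_towers, ctx.device, ctx.is_single_tower))
      return tf.constant(1.0)

    # Hypothetical two-GPU setup; device names are illustrative.
    distribution = tf.contrib.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"])
    distribution.call_for_each_tower(model_fn)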
__init__
__init__(distribution_strategy, tower_id)
Initialize self. See help(type(self)) for accurate signature.
__enter__
__enter__()
__exit__
__exit__(exception_type, exception_value, traceback)
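These methods make TowerContext a context manager. A minimal sketch of the assumed behavior, namely that entering the context installs it as the current tower context for the thread (as call_for_each_tower() is presumed to do internally):

    # Assumption: __enter__ makes ctx the current tower context.
    ctx = tf.contrib.distribute.TowerContext(distribution, tower_id=0)
    with ctx:
      assert tf.contrib.distribute.get_tower_context() is ctx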
merge_call
merge_call(merge_fn, *args, **kwargs)
Merge args across towers and run merge_fn in a cross-tower context.
This allows communication and coordination when there are multiple calls to a model function triggered by a call to distribution.call_for_each_tower(model_fn, ...).
See MirroredDistribution.call_for_each_tower() for an explanation.
Otherwise, this is equivalent to:

    distribution = get_distribution_strategy()
    with cross-tower-context(distribution):
        return merge_fn(distribution, *args, **kwargs)
Args:
    merge_fn: function that joins arguments from threads that are given as PerDevice values. It accepts a DistributionStrategy object as its first argument.
    *args: positional per-thread arguments for merge_fn.
    **kwargs: keyword per-thread arguments for merge_fn.

Returns:
    The return value of merge_fn, except for PerDevice values which are unpacked.
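A sketch of typical use, under the assumption that get_tower_context() and DistributionStrategy.unwrap() are available in this version of the contrib API: each tower computes a per-tower value, and merge_fn averages them once in the cross-tower context.

    import tensorflow as tf

    def model_fn():
      ctx = tf.contrib.distribute.get_tower_context()
      # A per-tower value; each tower thread computes its own copy.
      local_loss = tf.constant(1.0) * (ctx.tower_id + 1)

      def merge_fn(distribution, per_device_loss):
        # Runs once, in a cross-tower context. per_device_loss is a
        # PerDevice value; unwrap() (an assumption about this API)
        # yields one tensor per tower.
        losses = distribution.unwrap(per_device_loss)
        return tf.add_n(list(losses)) / len(losses)

      # Waits until every tower reaches this merge point.
      return ctx.merge_call(merge_fn, local_loss)

    distribution = tf.contrib.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"])
    mean_loss = distribution.call_for_each_tower(model_fn)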
tower_local_var_scope
tower_local_var_scope(reduce_method)
Alias for distribution_strategy.tower_local_var_scope().
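A hedged sketch of the alias in use; the "mean" reduce-method string and the interplay with tf.get_variable are assumptions about this version of the API:

    def model_fn():
      ctx = tf.contrib.distribute.get_tower_context()
      with ctx.tower_local_var_scope("mean"):
        # Assumption: variables created here are tower-local; each
        # tower keeps its own copy, combined with the "mean" reduce
        # method when read in a cross-tower context.
        counter = tf.get_variable("counter", initializer=0.0)
      return counter.assign_add(1.0)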
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/TowerContext