Returns the current tf.distribute.ReplicaContext or None.
tf.distribute.get_replica_context()
Returns None if in a cross-replica context.
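Note that execution:
1. starts in the default (single-replica) replica context (this function will return the default ReplicaContext object);
2. switches to the cross-replica context (in which case this will return None) when entering a with tf.distribute.Strategy.scope(): block;
3. switches to a (non-default) replica context inside strategy.run(fn, ...);
4. if fn calls get_replica_context().merge_call(merge_fn, ...), then inside merge_fn you are back in the cross-replica context (and again this function will return None).

Most tf.distribute.Strategy methods may only be executed in a cross-replica context; in a replica context, use the API of the tf.distribute.ReplicaContext object returned by this method instead.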
import tensorflow as tf

# Assumption: any tf.distribute.Strategy works here; MirroredStrategy is one example.
strategy = tf.distribute.MirroredStrategy()

assert tf.distribute.get_replica_context() is not None  # default
with strategy.scope():
  assert tf.distribute.get_replica_context() is None

  def f():
    replica_context = tf.distribute.get_replica_context()  # for strategy
    assert replica_context is not None
    tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
             " of ", replica_context.num_replicas_in_sync)

  strategy.run(f)
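The snippet above does not exercise the merge_call transition. The following is a minimal sketch that records what get_replica_context() returns at every stage, including inside merge_fn. It assumes tf.distribute.OneDeviceStrategy("/cpu:0") purely so it runs on any machine; the names has_replica_ctx, replica_fn and merge_fn are hypothetical and chosen for illustration.

```python
import tensorflow as tf

# Assumption: a single-device CPU strategy, picked only so the sketch runs
# on any machine; the context transitions are the same for other strategies.
strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

# Hypothetical bookkeeping dict: stage name -> whether get_replica_context()
# returned a ReplicaContext (True) or None (False) at that stage.
has_replica_ctx = {}


def merge_fn(distribution, value):
    # merge_fn receives the strategy as its first argument and runs in the
    # cross-replica context, so get_replica_context() is None here.
    has_replica_ctx["merge_fn"] = tf.distribute.get_replica_context() is not None
    return value


def replica_fn():
    # strategy.run() calls this in a (non-default) replica context.
    ctx = tf.distribute.get_replica_context()
    has_replica_ctx["replica_fn"] = ctx is not None
    return ctx.merge_call(merge_fn, args=(tf.constant(1.0),))


# Outside any strategy scope we are in the default replica context.
has_replica_ctx["default"] = tf.distribute.get_replica_context() is not None

with strategy.scope():
    # Inside strategy.scope() we are in the cross-replica context.
    has_replica_ctx["scope"] = tf.distribute.get_replica_context() is not None
    result = strategy.run(replica_fn)

print(has_replica_ctx)
# {'default': True, 'scope': False, 'replica_fn': True, 'merge_fn': False}
```

With a multi-device strategy such as MirroredStrategy, merge_fn would see per-replica values in its arguments rather than the plain tensor used here.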
Returns
The current tf.distribute.ReplicaContext object when in a replica context scope, else None. Within a particular block, exactly one of these two things will be true: get_replica_context() returns non-None, or tf.distribute.in_cross_replica_context() returns True.
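The "exactly one of these two things" invariant can be checked directly. This is a sketch, assuming the same single-device CPU strategy as a stand-in for any tf.distribute.Strategy; the helper record and the dict observed are hypothetical names, not part of the TensorFlow API.

```python
import tensorflow as tf

# Assumption: a single-device CPU strategy, used only to make the sketch runnable.
strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

# Hypothetical dict: stage name -> (replica context present?, in cross-replica context?)
observed = {}


def record(stage):
    in_replica = tf.distribute.get_replica_context() is not None
    in_cross_replica = tf.distribute.in_cross_replica_context()
    observed[stage] = (in_replica, in_cross_replica)
    # The documented invariant: exactly one of the two conditions holds.
    assert in_replica != in_cross_replica, stage


record("default")                            # default replica context
with strategy.scope():
    record("scope")                          # cross-replica context
    strategy.run(lambda: record("replica"))  # (non-default) replica context

print(observed)
# {'default': (True, False), 'scope': (False, True), 'replica': (True, False)}
```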
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/distribute/get_replica_context