A class wrapping information needed by a distribute function.

    tf.distribute.experimental.ValueContext(
        replica_id_in_sync_group=0, num_replicas_in_sync=1
    )

This is a context class that is passed to the value_fn in strategy.experimental_distribute_values_from_function and contains information about the compute replicas. The num_replicas_in_sync and replica_id_in_sync_group attributes can be used to customize the value on each replica.
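Example with a directly constructed context:

    import tensorflow as tf

    def value_fn(context):
        # Scale the replica ID by the total number of replicas.
        return context.replica_id_in_sync_group / context.num_replicas_in_sync

    context = tf.distribute.experimental.ValueContext(
        replica_id_in_sync_group=2, num_replicas_in_sync=4)
    per_replica_value = value_fn(context)
    per_replica_value  # 0.5
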
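Example with the context passed in by strategy.experimental_distribute_values_from_function:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    def value_fn(value_context):
        # Called once per replica with that replica's ValueContext.
        return value_context.num_replicas_in_sync

    distributed_values = (
        strategy.experimental_distribute_values_from_function(value_fn))
    local_result = strategy.experimental_local_results(distributed_values)
    local_result  # (1,) when the strategy runs a single replica
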
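
To illustrate how the two attributes can customize the value on each replica, here is a minimal sketch in which each replica picks its own slice of a shared list. The names `GLOBAL_BATCH` and `contexts` are hypothetical, and `types.SimpleNamespace` is used only as a stand-in that exposes the same two attributes as ValueContext, so the sketch runs without TensorFlow. In real use the strategy would construct the context and pass it to value_fn.

```python
from types import SimpleNamespace

# Hypothetical global batch; any indexable sequence would do.
GLOBAL_BATCH = list(range(8))


def value_fn(context):
    """Return the part of GLOBAL_BATCH that belongs to this replica."""
    replica_id = context.replica_id_in_sync_group
    num_replicas = context.num_replicas_in_sync
    # Strided slicing: replica i takes elements i, i + n, i + 2n, ...
    return GLOBAL_BATCH[replica_id::num_replicas]


# Stand-in contexts (assumption: 4 replicas), one per replica ID.
contexts = [
    SimpleNamespace(replica_id_in_sync_group=i, num_replicas_in_sync=4)
    for i in range(4)
]
shards = [value_fn(ctx) for ctx in contexts]
print(shards)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Because value_fn only reads the two attributes, the same function could be handed to strategy.experimental_distribute_values_from_function unchanged, provided the returned value is of a type the strategy accepts.
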
Args | |
---|---|
replica_id_in_sync_group | The ID of the current replica; should be an int in [0, num_replicas_in_sync). |
num_replicas_in_sync | The number of replicas that are in sync. |
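
This page does not say whether the constructor enforces the range given for replica_id_in_sync_group. Code that builds contexts by hand may want to check the arguments first. The helper below, `check_replica_args`, is a hypothetical sketch and not part of TensorFlow.

```python
def check_replica_args(replica_id_in_sync_group, num_replicas_in_sync):
    """Validate the documented constraint 0 <= replica_id < num_replicas.

    Hypothetical helper; returns the arguments unchanged when they are valid.
    """
    for name, value in (
        ("replica_id_in_sync_group", replica_id_in_sync_group),
        ("num_replicas_in_sync", num_replicas_in_sync),
    ):
        if not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    if num_replicas_in_sync < 1:
        raise ValueError("num_replicas_in_sync must be at least 1")
    if not 0 <= replica_id_in_sync_group < num_replicas_in_sync:
        raise ValueError(
            f"replica_id_in_sync_group must be in [0, {num_replicas_in_sync}), "
            f"got {replica_id_in_sync_group}"
        )
    return replica_id_in_sync_group, num_replicas_in_sync


print(check_replica_args(2, 4))  # (2, 4)
```

The returned pair can then be passed as keyword arguments to tf.distribute.experimental.ValueContext.
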
Attributes | |
---|---|
num_replicas_in_sync | Returns the number of compute replicas in sync. |
replica_id_in_sync_group | Returns the replica ID. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/experimental/ValueContext