tf.contrib.opt.ModelAverageCustomGetter

Class ModelAverageCustomGetter

Defined in tensorflow/contrib/opt/python/training/model_average_optimizer.py.

This custom_getter class is used to:

  1. Move trainable variables into the local variables collection and place them on the worker device.
  2. Generate the corresponding global variables.

Note that the class should be used with tf.train.replica_device_setter, so that the global center variables and the global step variable can be placed on the ps device. In addition, create variables with 'tf.get_variable' instead of 'tf.Variable', because only 'tf.get_variable' goes through this custom getter.
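The two steps above can be modeled in plain Python. The sketch below is not TensorFlow code: FakeVariable, base_getter, SketchModelAverageGetter, the collection strings and the "global/" name prefix are all hypothetical stand-ins, and the ps device is passed in explicitly to mimic what tf.train.replica_device_setter would do. It only illustrates the routing: a trainable variable comes back as a local copy on the worker device, and a non-trainable global counterpart is recorded for it.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List

# Hypothetical collection names standing in for tf.GraphKeys entries.
LOCAL_VARIABLES = "local_variables"
GLOBAL_VARIABLES = "global_variables"


@dataclass(eq=False)  # eq=False keeps identity hashing so variables can be dict keys
class FakeVariable:
    """Minimal record of a variable: name, value, placement and collections."""
    name: str
    value: float
    device: str
    trainable: bool
    collections: List[str] = field(default_factory=list)


def base_getter(name, value, device, trainable, collections):
    """Plays the role of the default getter that actually creates a variable."""
    return FakeVariable(name, value, device, trainable, list(collections))


class SketchModelAverageGetter:
    """Plain-Python model of the two steps above; not the TensorFlow class."""

    def __init__(self, worker_device: str, ps_device: str):
        self.worker_device = worker_device
        # Assumed: where replica_device_setter would put global variables.
        self.ps_device = ps_device
        self.local_to_global: Dict[FakeVariable, FakeVariable] = {}

    def __call__(self, getter: Callable, name: str, value: float,
                 trainable: bool, collections: List[str]) -> FakeVariable:
        if not trainable:
            # Non-trainable variables (e.g. a global step) are not duplicated;
            # here they are simply placed on the ps device.
            return getter(name, value, self.ps_device, trainable, collections)
        # Step 1: the trainable variable goes into the local collection on the worker.
        local_var = getter(name, value, self.worker_device, True, [LOCAL_VARIABLES])
        # Step 2: a non-trainable global counterpart on the ps device,
        # initialized from the local copy ("global/" prefix is hypothetical).
        global_var = getter("global/" + name, local_var.value, self.ps_device,
                            False, [GLOBAL_VARIABLES])
        self.local_to_global[local_var] = global_var
        # Model code keeps working with the local copy.
        return local_var


ma_getter = SketchModelAverageGetter("/job:worker/task:0", "/job:ps/cpu:0")
hid_b = ma_getter(base_getter, "hid_b", 0.0, trainable=True,
                  collections=[GLOBAL_VARIABLES])
step = ma_getter(base_getter, "global_step", 0, trainable=False,
                 collections=[GLOBAL_VARIABLES])
print(hid_b.device, hid_b.collections)         # /job:worker/task:0 ['local_variables']
print(ma_getter.local_to_global[hid_b].name)  # global/hid_b
print(step.device)                            # /job:ps/cpu:0
```

Keeping the local-to-global mapping on the getter object is what lets a later averaging step find, for each local copy, the shared variable it should be synchronized with.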

For example:

    import tensorflow as tf
    from tensorflow.contrib.opt import ModelAverageCustomGetter

    # worker_device, cluster, IMAGE_PIXELS and FLAGS are defined elsewhere
    # in the training script.
    ma_custom_getter = ModelAverageCustomGetter(worker_device)
    with tf.device(
        tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=cluster)), \
        tf.variable_scope('', custom_getter=ma_custom_getter):
        hid_w = tf.get_variable(
            initializer=tf.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.get_variable(
            initializer=tf.zeros([FLAGS.hidden_units]),
            name="hid_b")

Methods

__init__

__init__(worker_device)

Create a new ModelAverageCustomGetter.

Args:

  • worker_device: String. Name of the worker device (for example, "/job:worker/task:0") on which the local copies of trainable variables are placed.

__call__

__call__(
    getter,
    name,
    trainable,
    collections,
    *args,
    **kwargs
)

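Call self as a function. This is invoked by 'tf.get_variable' inside a tf.variable_scope whose custom_getter is this object. getter is the underlying getter that creates the variable; name, trainable and collections describe the requested variable, and any remaining arguments are forwarded to getter.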
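The calling convention can also be sketched without TensorFlow. In the hypothetical code below, get_variable, _true_getter and LoggingGetter are illustrative names, not TensorFlow APIs; the only assumption is that the scope hands the custom getter the real getter first and the variable's arguments by keyword, so an object with the parameter order of __call__ can inspect them and forward the call.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical in-memory variable store (not a TensorFlow API).
_store: Dict[str, Dict[str, Any]] = {}


def _true_getter(name: str, trainable: bool = True,
                 collections: Optional[List[str]] = None,
                 **kwargs: Any) -> Dict[str, Any]:
    """Stand-in for the default getter: creates a variable record once, then reuses it."""
    if name not in _store:
        _store[name] = {"name": name, "trainable": trainable,
                        "collections": list(collections or []), **kwargs}
    return _store[name]


def get_variable(name: str, custom_getter: Optional[Callable] = None,
                 trainable: bool = True,
                 collections: Optional[List[str]] = None,
                 **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical get_variable that delegates to a custom getter when one is set."""
    if collections is None:
        collections = ["global_variables"]
    if custom_getter is None:
        return _true_getter(name, trainable=trainable,
                            collections=collections, **kwargs)
    # The custom getter receives the real getter plus the variable's arguments,
    # all by keyword, so it can match the __call__ signature documented above.
    return custom_getter(getter=_true_getter, name=name, trainable=trainable,
                         collections=collections, **kwargs)


class LoggingGetter:
    """Toy custom getter using the same parameter order as __call__."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def __call__(self, getter, name, trainable, collections, *args, **kwargs):
        self.calls.append(name)
        # Forward to the wrapped getter; a real custom getter could change
        # collections or device placement here before forwarding.
        return getter(name, *args, trainable=trainable,
                      collections=collections, **kwargs)


logger = LoggingGetter()
hid_w = get_variable("hid_w", custom_getter=logger, shape=(784, 100))
print(logger.calls)          # ['hid_w']
print(hid_w["collections"])  # ['global_variables']
print(hid_w["shape"])        # (784, 100)
```

Because the custom getter only wraps the real getter, it can change how a variable is created (collection, device, extra bookkeeping) without the model code that calls get_variable having to change.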

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/opt/ModelAverageCustomGetter