tf.contrib.opt.AGNCustomGetter


Class AGNCustomGetter

A custom getter class that does the following:

  1. Moves trainable variables into the local collection and places them on the worker device.
  2. Generates global variables (the global center variables).
  3. Generates gradient variables (grad variables), which accumulate the sum of the gradients, and places them on the worker device.

Note that this class should be used together with tf.train.replica_device_setter, so that the global center variables and the global step variable are placed on the ps device.
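The three steps can be modeled without TensorFlow. The pure-Python sketch below is a toy under stated assumptions, not the TensorFlow implementation: ToyVariable, toy_base_getter, ToyAGNGetter, the "global_center/" and "grad/" name prefixes, and the explicit ps_device argument are all hypothetical. In particular, the ps placement that tf.train.replica_device_setter would normally choose is simply passed in here.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical stand-ins for TensorFlow's LOCAL_VARIABLES / GLOBAL_VARIABLES keys.
LOCAL_VARIABLES = "local_variables"
GLOBAL_VARIABLES = "variables"


@dataclass
class ToyVariable:
    """Stand-in for a variable: only name, placement and collections."""
    name: str
    device: Optional[str]
    trainable: bool
    collections: Tuple[str, ...]


def toy_base_getter(name, trainable, collections, device=None):
    """Hypothetical default getter that simply creates a ToyVariable."""
    return ToyVariable(name, device, trainable, tuple(collections))


class ToyAGNGetter:
    """Toy model of the three steps; not the TensorFlow implementation."""

    def __init__(self, worker_device, ps_device="/job:ps/task:0"):
        self._worker_device = worker_device
        # In TensorFlow, tf.train.replica_device_setter chooses the ps placement;
        # this toy takes it as an explicit (hypothetical) argument instead.
        self._ps_device = ps_device
        self.global_map = {}  # local variable name -> global center variable
        self.grad_map = {}    # local variable name -> gradient accumulator

    def __call__(self, getter, name, trainable, collections):
        if not trainable:
            # Non-trainable variables pass through; placement left unset (None).
            return getter(name, trainable, collections)
        # 1. Keep the variable trainable but register it only as a local
        #    variable, placed on the worker device.
        local = getter(name, True, [LOCAL_VARIABLES], device=self._worker_device)
        # 2. Create the global center variable on the parameter server.
        center = getter("global_center/" + name, False, [GLOBAL_VARIABLES],
                        device=self._ps_device)
        # 3. Create a worker-side variable that will hold the gradient sum.
        grad = getter("grad/" + name, False, [LOCAL_VARIABLES],
                      device=self._worker_device)
        self.global_map[local.name] = center
        self.grad_map[local.name] = grad
        return local


agn = ToyAGNGetter(worker_device="/job:worker/task:0")
w = agn(toy_base_getter, "dense/kernel", trainable=True,
        collections=[GLOBAL_VARIABLES])
print(w.collections, w.device)                # ('local_variables',) /job:worker/task:0
print(agn.global_map["dense/kernel"].device)  # /job:ps/task:0
```

The getter returns only the local copy to the caller, so model code keeps working on worker-local variables while the two side maps let an optimizer find the matching center and gradient variables later.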

__init__


__init__(worker_device)

Args:

  worker_device: The worker device on which the gradient variables (grad variables) are placed.

Methods

tf.contrib.opt.AGNCustomGetter.__call__


__call__(
    getter,
    name,
    trainable,
    collections,
    *args,
    **kwargs
)

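Call self as a function. When an instance is installed as the custom_getter of a variable scope, this method runs in place of the default getter: getter is the wrapped getter, and name, trainable, collections and the remaining arguments describe the variable being requested.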
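In TensorFlow 1.x, a custom getter is installed through the custom_getter argument of a variable scope, and the framework then calls it with the real getter as its first argument. The pure-Python sketch below models only that calling convention; default_getter, get_variable and TracingGetter are hypothetical toy stand-ins, not TensorFlow APIs.

```python
def default_getter(name, *args, **kwargs):
    """Hypothetical stand-in for the framework's real variable getter."""
    return {"name": name, "args": args, **kwargs}


def get_variable(name, trainable=True, collections=None, custom_getter=None,
                 **kwargs):
    """Toy get_variable that routes creation through an optional custom getter."""
    if collections is None:
        collections = ["variables"]
    if custom_getter is None:
        return default_getter(name, trainable=trainable,
                              collections=collections, **kwargs)
    # The custom getter receives the real getter as its first argument and
    # decides whether (and how) to call it.
    return custom_getter(default_getter, name, trainable=trainable,
                         collections=collections, **kwargs)


class TracingGetter:
    """Same __call__ shape as AGNCustomGetter; records the call, then forwards."""

    def __init__(self):
        self.calls = []

    def __call__(self, getter, name, trainable, collections, *args, **kwargs):
        self.calls.append((name, trainable))
        # Forward every argument unchanged to the wrapped getter.
        return getter(name, *args, trainable=trainable,
                      collections=collections, **kwargs)


tracer = TracingGetter()
v = get_variable("w", custom_getter=tracer, shape=(3,))
print(tracer.calls)  # [('w', True)]
print(v["shape"])    # (3,)
```

Because trainable and collections are named parameters of __call__, a getter such as AGNCustomGetter can inspect and rewrite them before delegating, while *args and **kwargs carry everything else (shape, initializer, and so on) through untouched.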