VBN
Defined in tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
.
A class to perform virtual batch normalization.
This technique was first introduced in Improved Techniques for Training GANs
(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch normalization on a minibatch, it fixes a reference subset of the data to use for calculating normalization statistics.
To do this, we calculate the reference batch mean and mean square, and modify those statistics for each example. We use mean square instead of variance, since it is linear.
Note that if center
or scale
variables are created, they are shared between all calls to this object.
The __init__
API is intended to mimic tf.layers.batch_normalization
as closely as possible.
__init__
__init__( reference_batch, axis=-1, epsilon=0.001, center=True, scale=True, beta_initializer=tf.zeros_initializer(), gamma_initializer=tf.ones_initializer(), beta_regularizer=None, gamma_regularizer=None, trainable=True, name=None, batch_axis=0 )
Initialize virtual batch normalization object.
We precompute the 'mean' and 'mean squared' of the reference batch, so that __call__
is efficient. This means that the axis must be supplied when the object is created, not when it is called.
We precompute 'square mean' instead of 'variance', because the square mean can be easily adjusted on a per-example basis.
reference_batch
: A minibatch tensors. This will form the reference data from which the normalization statistics are calculated. See https://arxiv.org/abs/1606.03498 for more details.axis
: Integer, the axis that should be normalized (typically the features axis). For instance, after a Convolution2D
layer with data_format="channels_first"
, set axis=1
in BatchNormalization
.epsilon
: Small float added to variance to avoid dividing by zero.center
: If True, add offset of beta
to normalized tensor. If False, beta
is ignored.scale
: If True, multiply by gamma
. If False, gamma
is not used. When the next layer is linear (also e.g. nn.relu
), this can be disabled since the scaling can be done by the next layer.beta_initializer
: Initializer for the beta weight.gamma_initializer
: Initializer for the gamma weight.beta_regularizer
: Optional regularizer for the beta weight.gamma_regularizer
: Optional regularizer for the gamma weight.trainable
: Boolean, if True
also add variables to the graph collection GraphKeys.TRAINABLE_VARIABLES
(see tf.Variable).name
: String, the name of the ops.batch_axis
: The axis of the batch dimension. This dimension is treated differently in virtual batch normalization
vs batch normalization
.ValueError
: If reference_batch
has unknown dimensions at graph construction.ValueError
: If batch_axis
is the same as axis
.__call__
__call__(inputs)
Run virtual batch normalization on inputs.
inputs
: Tensor input.A virtual batch normalized version of inputs
.
ValueError
: If inputs
shape isn't compatible with the reference batch.reference_batch_normalization
reference_batch_normalization()
Return the reference batch, but batch normalized.
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/gan/features/VBN