tf.contrib.gan.losses.wargs.wasserstein_gradient_penalty( real_data, generated_data, generator_inputs, discriminator_fn, discriminator_scope, epsilon=1e-10, target=1.0, one_sided=False, weights=1.0, scope=None, loss_collection=tf.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False )
Defined in tensorflow/contrib/gan/python/losses/python/losses_impl.py.
The gradient penalty for the Wasserstein discriminator loss.
See Improved Training of Wasserstein GANs (https://arxiv.org/abs/1704.00028) for more details.
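The paper's penalty can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the TFGAN implementation:

- The discriminator gradient comes from a hypothetical `grad_fn` callable instead of TensorFlow's automatic differentiation.
- The per-example interpolation coefficients `alpha` are passed in explicitly rather than drawn from U[0, 1] as in the paper.
- `target` is assumed to rescale the norm as `slope / target - 1`. When `target` is 1.0 this matches the paper's (||grad D(x_hat)||_2 - 1)^2.

The function name `gradient_penalty_sketch` is hypothetical. Its `epsilon`, `target` and `one_sided` parameters mirror the signature above. It returns one penalty per example, before any weighting or reduction.

```python
import numpy as np


def gradient_penalty_sketch(real, fake, grad_fn, alpha, epsilon=1e-10,
                            target=1.0, one_sided=False):
    """Per-example WGAN-GP penalties for a batch (hypothetical NumPy sketch).

    grad_fn(x) must return dD/dx for a batch x; it stands in for the
    automatic differentiation TensorFlow would perform.
    """
    real = np.asarray(real, dtype=float)
    fake = np.asarray(fake, dtype=float)
    # One coefficient per example, broadcast over all non-batch axes.
    alpha = np.asarray(alpha, dtype=float).reshape(
        [real.shape[0]] + [1] * (real.ndim - 1))
    # Points on the straight line between each real and generated example.
    interpolates = real + alpha * (fake - real)
    grads = grad_fn(interpolates)
    # Squared L2 norm of each example's gradient, summed over non-batch axes.
    grad_sq = np.sum(np.square(grads), axis=tuple(range(1, grads.ndim)))
    # epsilon is added before the square root; under autodiff this keeps the
    # derivative of sqrt finite when the gradient norm is zero.
    slopes = np.sqrt(grad_sq + epsilon)
    # Assumed form of the target scaling (see the lead-in).
    penalties = slopes / target - 1.0
    if one_sided:
        # arXiv:1709.08894 variant: only slopes above the target are penalized.
        penalties = np.maximum(0.0, penalties)
    return np.square(penalties)


# Usage: D(x) = 0.5 * ||x||^2 has gradient x, so grad_fn is the identity.
real = np.zeros((2, 2))
fake = np.array([[6.0, 8.0], [0.3, 0.4]])
print(gradient_penalty_sketch(real, fake, lambda x: x, alpha=[0.5, 1.0]))
print(gradient_penalty_sketch(real, fake, lambda x: x, alpha=[0.5, 1.0],
                              one_sided=True))
```

In the usage lines, the two interpolates are [3, 4] and [0.3, 0.4], so their gradient norms are about 5 and 0.5. The two-sided penalties are therefore roughly 16 and 0.25. With `one_sided=True` the second penalty becomes 0, because its norm is below the target.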
Args:
real_data: Real data.
generated_data: Output of the generator.
generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator.
discriminator_fn: A discriminator function that conforms to the TFGAN API.
discriminator_scope: If not None, reuse discriminators from this scope.
epsilon: A small positive number added for numerical stability when computing the gradient norm.
target: Optional Python number or Tensor indicating the target value of the gradient norm. Defaults to 1.0.
one_sided: If True, the penalty proposed in https://arxiv.org/abs/1709.08894 is used. Defaults to False.
weights: Optional Tensor whose rank is either 0 or the same rank as real_data and generated_data, and which must be broadcastable to them (i.e., all dimensions must be either 1 or the same as the corresponding dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: Collection to which this loss will be added.
reduction: A tf.losses.Reduction to apply to the loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on reduction.
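The `weights` and `reduction` arguments are what determine the shape of the returned loss. A NumPy sketch can show how the reduction modes combine weighted losses.

The sketch rests on these assumptions:

- It follows the documented meaning of the `tf.losses.Reduction` modes, but it is not TensorFlow code.
- Mode names are passed as plain strings.
- The losses are assumed to be per-example penalties, with `weights` broadcast against them.
- `reduce_weighted_sketch` is a hypothetical name.

```python
import numpy as np


def reduce_weighted_sketch(losses, weights=1.0,
                           reduction="SUM_BY_NONZERO_WEIGHTS"):
    """Hypothetical NumPy stand-in for the tf.losses.Reduction modes."""
    losses = np.asarray(losses, dtype=float)
    # Weights must broadcast to the loss shape (each dim is 1 or equal);
    # np.broadcast_to raises ValueError otherwise.
    weights = np.broadcast_to(np.asarray(weights, dtype=float), losses.shape)
    weighted = losses * weights
    if reduction == "NONE":
        return weighted  # same shape as losses
    total = weighted.sum()
    if reduction == "SUM":
        return total
    if reduction == "MEAN":
        denom = weights.sum()  # sum divided by the sum of weights
    elif reduction == "SUM_OVER_BATCH_SIZE":
        denom = losses.size  # sum divided by the number of elements
    elif reduction == "SUM_BY_NONZERO_WEIGHTS":
        denom = np.count_nonzero(weights)  # sum divided by nonzero weights
    else:
        raise ValueError("unknown reduction: %s" % reduction)
    # Safe division: return 0 when there is nothing to average over.
    return total / denom if denom else 0.0


penalties = np.array([16.0, 0.25, 4.0])  # e.g. per-example penalties
print(reduce_weighted_sketch(penalties))  # scalar
print(reduce_weighted_sketch(penalties, [2.0, 0.0, 1.0], "NONE"))  # shape (3,)
```

With the default `weights=1.0`, every weight is nonzero, so SUM_BY_NONZERO_WEIGHTS reduces to a plain mean and yields a scalar. A zero weight removes an example from both the sum and the count. Only NONE keeps the unreduced per-example shape.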
Raises:
ValueError: If the rank of the data Tensors is unknown.
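One plausible reason for the rank requirement is the interpolation step. This is offered as an assumption, not a reading of the TFGAN source.

If one interpolation coefficient is drawn per example and broadcast over the remaining axes, it needs a shape of the form [batch, 1, ..., 1]. The number of trailing 1s can only be chosen once the rank is known. The helper below is hypothetical and only computes that shape; it is not part of TFGAN.

```python
def interpolation_coeff_shape(data_shape):
    """Shape of a per-example interpolation coefficient (hypothetical helper).

    data_shape lists the static dimensions of the data Tensor, with None for
    an unknown dimension, or is None itself when even the rank is unknown.
    """
    if data_shape is None:
        # Without the rank we cannot tell how many trailing 1s are needed.
        raise ValueError("data Tensors must have a known rank")
    # The batch size may still be unknown (None); only the rank matters here.
    return [data_shape[0]] + [1] * (len(data_shape) - 1)


print(interpolation_coeff_shape([None, 28, 28, 3]))  # [None, 1, 1, 1]
```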
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/gan/losses/wargs/wasserstein_gradient_penalty