
tf.contrib.gan.losses.wargs.wasserstein_gradient_penalty

tf.contrib.gan.losses.wargs.wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs,
    discriminator_fn,
    discriminator_scope,
    epsilon=1e-10,
    target=1.0,
    one_sided=False,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False
)

Defined in tensorflow/contrib/gan/python/losses/python/losses_impl.py.

The gradient penalty for the Wasserstein discriminator loss.

See Improved Training of Wasserstein GANs (https://arxiv.org/abs/1704.00028) for more details.
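
The penalty from that paper constrains the norm of the discriminator's gradient at random interpolates between real and generated samples. As a sketch in the paper's notation (the target argument generalizes the unit target norm, and weights acts as the penalty coefficient lambda):

    \hat{x} = \alpha \, x_{\text{real}} + (1 - \alpha) \, x_{\text{gen}}, \qquad \alpha \sim U[0, 1]

    \text{penalty} \approx \text{weights} \cdot \mathbb{E}_{\hat{x}}\!\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]

With one_sided=True, only gradient norms exceeding the target are penalized, following the one-sided variant referenced below.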

Args:

  • real_data: Real data.
  • generated_data: Output of the generator.
  • generator_inputs: Exact argument to pass to the generator; also used as optional conditioning for the discriminator.
  • discriminator_fn: A discriminator function that conforms to TFGAN API.
  • discriminator_scope: If not None, reuse the discriminator variables from this scope.
  • epsilon: A small positive number added for numerical stability when computing the gradient norm.
  • target: Optional Python number or Tensor indicating the target value of gradient norm. Defaults to 1.0.
  • one_sided: If True, penalty proposed in https://arxiv.org/abs/1709.08894 is used. Defaults to False.
  • weights: Optional Tensor whose rank is either 0, or the same rank as real_data and generated_data, and must be broadcastable to them (i.e., all dimensions must be either 1, or the same as the corresponding dimension).
  • scope: The scope for the operations performed in computing the loss.
  • loss_collection: Collection to which this loss will be added.
  • reduction: A tf.losses.Reduction to apply to the loss.
  • add_summaries: Whether or not to add summaries for the loss.

Returns:

A loss Tensor. The shape depends on reduction.

Raises:

  • ValueError: If the rank of data Tensors is unknown.
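
A minimal usage sketch (TensorFlow 1.x). The discriminator_fn, shapes, and scope name below are illustrative, not part of the API; the only requirement shown is that the discriminator variables already exist under the scope passed as discriminator_scope so they can be reused when the penalty is computed.

import tensorflow as tf

# Illustrative discriminator conforming to the TFGAN API:
# it receives (data, generator_inputs) and returns logits.
def discriminator_fn(data, generator_inputs):
  net = tf.layers.dense(data, 64, activation=tf.nn.relu, name='fc1')
  return tf.layers.dense(net, 1, name='logits')

real_data = tf.random_normal([16, 10])        # stand-in batch of real samples
generated_data = tf.random_normal([16, 10])   # stand-in for the generator's output
generator_inputs = tf.random_normal([16, 5])  # noise / conditioning fed to the generator

# Build the discriminator variables once so they can be reused
# inside the gradient-penalty computation.
with tf.variable_scope('Discriminator') as dis_scope:
  _ = discriminator_fn(real_data, generator_inputs)

gradient_penalty = tf.contrib.gan.losses.wargs.wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs,
    discriminator_fn,
    dis_scope,
    weights=10.0,  # penalty coefficient, e.g. the lambda = 10 suggested in the WGAN-GP paper
    add_summaries=True)

# gradient_penalty is a scalar loss Tensor that can be added to the
# Wasserstein discriminator loss before optimization.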

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/gan/losses/wargs/wasserstein_gradient_penalty