tf.contrib.gan.losses.wargs.combine_adversarial_loss

tf.contrib.gan.losses.wargs.combine_adversarial_loss(
    main_loss,
    adversarial_loss,
    weight_factor=None,
    gradient_ratio=None,
    gradient_ratio_epsilon=1e-06,
    variables=None,
    scalar_summaries=True,
    gradient_summaries=True,
    scope=None
)

Defined in tensorflow/contrib/gan/python/losses/python/losses_impl.py.

Utility to combine main and adversarial losses.

This utility combines the main and adversarial losses in one of two ways:

  • Fixed coefficient on the adversarial loss. Use weight_factor in this case.
  • Fixed ratio of gradient magnitudes. Use gradient_ratio in this case. This is often used to make sure both losses affect the weights roughly equally, as in https://arxiv.org/pdf/1705.05823.
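The two modes can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: combine_losses is a hypothetical helper, the gradient magnitudes are assumed to be precomputed numbers, and the coefficient for the gradient-ratio mode is derived from the gradient_ratio definition under Args, with the epsilon added to its denominator. Because the coefficient here is a plain number, no gradient flows through it; how the TensorFlow implementation treats the coefficient during differentiation is not stated on this page.

```python
def combine_losses(main_loss, adversarial_loss, main_grad_mag, adv_grad_mag,
                   weight_factor=None, gradient_ratio=None,
                   gradient_ratio_epsilon=1e-6):
    """Hypothetical sketch: combine two scalar losses in one of two modes.

    main_grad_mag and adv_grad_mag are assumed to be precomputed gradient
    magnitudes of main_loss and adversarial_loss, respectively.
    Returns (combined_loss, adversarial_coefficient).
    """
    # Exactly one mode must be selected.
    if (weight_factor is None) == (gradient_ratio is None):
        raise ValueError(
            "Exactly one of weight_factor and gradient_ratio must be non-None.")
    if weight_factor is not None:
        # Mode 1: fixed coefficient on the adversarial loss.
        coeff = weight_factor
    else:
        # Mode 2: choose coeff so that
        #   main_grad_mag / (coeff * adv_grad_mag) ~= gradient_ratio.
        # The epsilon keeps the denominator away from zero.
        coeff = main_grad_mag / (gradient_ratio * adv_grad_mag
                                 + gradient_ratio_epsilon)
    # coeff is a plain number, so it is constant with respect to the weights.
    return main_loss + coeff * adversarial_loss, coeff


# Losses 2*v and 3*v of a single variable v, evaluated at v = 1:
# loss values are 2 and 3, and gradient magnitudes are also 2 and 3.
print(combine_losses(2.0, 3.0, 2.0, 3.0, weight_factor=2.0))   # (8.0, 2.0)
print(combine_losses(2.0, 3.0, 2.0, 3.0, gradient_ratio=0.5))  # loss close to 6.0
```

In the gradient-ratio mode the coefficient adapts to the current gradient magnitudes, so the adversarial term's influence stays at a fixed proportion of the main loss's influence even as the raw loss scales drift during training.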

Optionally, summaries can be created to visualize the scalar and gradient behavior of the losses (see scalar_summaries and gradient_summaries).

Args:

  • main_loss: A floating scalar Tensor indicating the main loss.
  • adversarial_loss: A floating scalar Tensor indicating the adversarial loss.
  • weight_factor: If not None, the coefficient by which to multiply the adversarial loss. Exactly one of this and gradient_ratio must be non-None.
  • gradient_ratio: If not None, the desired ratio of the gradient magnitudes. Specifically, gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss). Exactly one of this and weight_factor must be non-None.
  • gradient_ratio_epsilon: An epsilon added to the denominator of the adversarial loss coefficient to avoid division by zero.
  • variables: List of variables with respect to which gradients are computed. If None, defaults to all trainable variables.
  • scalar_summaries: If True, create scalar summaries of the losses.
  • gradient_summaries: If True, create gradient summaries of the losses.
  • scope: Optional name scope.
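To make the gradient_ratio definition concrete, the following sketch assumes that grad_mag is the global L2 norm over the gradients of all selected variables (an assumption; this page does not define grad_mag). Gradients are represented as a hypothetical dict mapping variable names to flat lists of floats, and grad_mag and gradient_ratio_coeff are illustrative helpers, not library functions. The sketch picks the adversarial coefficient from the definition, checks that the resulting magnitude ratio matches gradient_ratio, and shows how restricting variables changes the magnitudes.

```python
import math


def grad_mag(grads, variables=None):
    """Global L2 norm of the gradients (assumed meaning of grad_mag).

    grads maps a hypothetical variable name to that variable's gradient,
    flattened to a list of floats. If variables is None, every entry is
    used, mirroring the default of using all trainable variables.
    """
    names = grads.keys() if variables is None else variables
    return math.sqrt(sum(g * g for name in names for g in grads[name]))


def gradient_ratio_coeff(main_grads, adv_grads, gradient_ratio,
                         gradient_ratio_epsilon=1e-6, variables=None):
    """Coefficient c with grad_mag(main) / grad_mag(c * adv) ~= gradient_ratio."""
    main_mag = grad_mag(main_grads, variables)
    adv_mag = grad_mag(adv_grads, variables)
    # Epsilon added to the denominator to avoid division by zero.
    return main_mag / (gradient_ratio * adv_mag + gradient_ratio_epsilon)


# Hypothetical gradients of each loss with respect to two variables.
main_grads = {"w": [3.0, 0.0], "b": [4.0]}  # global norm 5.0
adv_grads = {"w": [0.0, 1.0], "b": [0.0]}   # global norm 1.0

c = gradient_ratio_coeff(main_grads, adv_grads, gradient_ratio=2.0)
# Scaling the adversarial loss by c scales its gradient magnitude by c.
scaled_adv_mag = c * grad_mag(adv_grads)
print(grad_mag(main_grads) / scaled_adv_mag)  # close to 2.0

# Restricting to variable "w" only: magnitudes become 3.0 and 1.0.
c_w = gradient_ratio_coeff(main_grads, adv_grads, gradient_ratio=2.0,
                           variables=["w"])
print(c_w)
```

Because the coefficient depends on which variables are included, the same pair of losses can be balanced differently depending on the variables argument.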

Returns:

A floating scalar Tensor indicating the desired combined loss.

Raises:

  • ValueError: Malformed input, for example if both or neither of weight_factor and gradient_ratio are set.

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/gan/losses/wargs/combine_adversarial_loss