GANHead
tf.contrib.gan.estimator.GANHead
tf.contrib.gan.estimator.head.GANHead
Defined in tensorflow/contrib/gan/python/estimator/python/head_impl.py.
Head for a GAN.
logits_dimension
Size of the last dimension of the logits Tensor.
Typically, logits is of shape [batch_size, logits_dimension].
Returns:
The expected size of the logits tensor.
name
The name of this head.
Returns:
A string.
__init__
__init__(
generator_loss_fn,
discriminator_loss_fn,
generator_optimizer,
discriminator_optimizer,
use_loss_summaries=True,
get_hooks_fn=None,
name=None
)
Head for GAN training.
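As a rough illustration of the loss-function contract the constructor expects (a callable that takes a GANModel and returns a scalar), the sketch below uses a plain-Python stand-in. `ToyGANModel` is a hypothetical namedtuple that carries only two of the fields found on TFGAN's `GANModel`, and it holds Python floats rather than tensors. The Wasserstein-style formulas are chosen purely for illustration.

```python
import collections

# Hypothetical stand-in for TFGAN's GANModel tuple: only the two fields
# used by the losses below, holding Python floats instead of tensors.
ToyGANModel = collections.namedtuple(
    "ToyGANModel", ["discriminator_real_outputs", "discriminator_gen_outputs"])


def _mean(values):
    """Arithmetic mean of a non-empty list of numbers."""
    return sum(values) / len(values)


def toy_generator_loss(gan_model):
    """Wasserstein-style generator loss: -mean(D(G(z)))."""
    return -_mean(gan_model.discriminator_gen_outputs)


def toy_discriminator_loss(gan_model):
    """Wasserstein-style discriminator loss: mean(D(G(z))) - mean(D(x))."""
    return (_mean(gan_model.discriminator_gen_outputs)
            - _mean(gan_model.discriminator_real_outputs))


model = ToyGANModel(discriminator_real_outputs=[2.0, 6.0],
                    discriminator_gen_outputs=[1.0, 3.0])
print(toy_generator_loss(model))
print(toy_discriminator_loss(model))
```

Both functions read from the same model tuple, which is why a single GANModel can be passed to the generator and the discriminator loss alike.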
Args:
generator_loss_fn: A TFGAN loss function for the generator. Takes a GANModel and returns a scalar.
discriminator_loss_fn: Same as generator_loss_fn, but for the discriminator.
generator_optimizer: The optimizer for generator updates.
discriminator_optimizer: Same as generator_optimizer, but for the discriminator updates.
use_loss_summaries: If True, adds loss summaries. If False, does not. If None, uses defaults.
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list of hooks. Defaults to train.get_sequential_train_hooks().
name: Name of the head. If provided, summary and metrics keys will be suffixed by "/" + name.
create_estimator_spec
create_estimator_spec(
features,
mode,
logits,
labels=None,
train_op_fn=tf.contrib.gan.gan_train_ops
)
Returns EstimatorSpec that a model_fn can return.
See Head for more details.
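The two error conditions documented for this method (features must be None, and train_op_fn is required in train mode) can be sketched in plain Python. `check_head_inputs` is a hypothetical helper written for illustration and is not part of TFGAN. It assumes the train mode key is the string "train", the value of `tf.estimator.ModeKeys.TRAIN`.

```python
TRAIN = "train"  # assumed to match tf.estimator.ModeKeys.TRAIN


def check_head_inputs(features, mode, train_op_fn):
    """Hypothetical helper mirroring the two documented ValueError cases."""
    if features is not None:
        raise ValueError("features must be None, got: %r" % (features,))
    if mode == TRAIN and train_op_fn is None:
        raise ValueError("train_op_fn is required in train mode.")


# Valid call: no features, and a train_op_fn is supplied in train mode.
check_head_inputs(features=None, mode=TRAIN, train_op_fn=lambda *args: None)

# Invalid call: features were passed, so the first check raises.
try:
    check_head_inputs(features={"x": 1}, mode=TRAIN, train_op_fn=None)
except ValueError as err:
    print("rejected:", err)
```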
Args:
features: Must be None.
mode: Estimator's ModeKeys.
logits: A GANModel tuple.
labels: Must be None.
train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, and discriminator optimizer, and returns a GANTrainOps tuple. For example, this function can come from TFGAN's train.py library, or can be custom.
Returns:
EstimatorSpec.
Raises:
ValueError: If features isn't None.
ValueError: If train_op_fn isn't provided in train mode.
create_loss
create_loss(
features,
mode,
logits,
labels
)
Returns a GANLoss tuple from the provided GANModel.
See Head for more details.
Args:
features: Input dict of Tensor objects. Unused.
mode: Estimator's ModeKeys.
logits: A GANModel tuple.
labels: Must be None.
Returns:
A GANLoss tuple.
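To make the return value concrete, the following sketch bundles a generator loss and a discriminator loss into one tuple. `ToyGANLoss` and `make_gan_loss` are hypothetical stand-ins: TFGAN's real `GANLoss` holds tensors, and the head's actual implementation may differ from this.

```python
import collections

# Hypothetical stand-in for TFGAN's GANLoss tuple.
ToyGANLoss = collections.namedtuple(
    "ToyGANLoss", ["generator_loss", "discriminator_loss"])


def make_gan_loss(gan_model, generator_loss_fn, discriminator_loss_fn):
    """Applies each loss function to the same model and bundles the results."""
    return ToyGANLoss(
        generator_loss=generator_loss_fn(gan_model),
        discriminator_loss=discriminator_loss_fn(gan_model))


# Any object can play the GANModel role here; a dict keeps the sketch short.
toy_model = {"gen_score": 0.25, "real_score": 0.75}
loss = make_gan_loss(
    toy_model,
    generator_loss_fn=lambda m: -m["gen_score"],
    discriminator_loss_fn=lambda m: m["gen_score"] - m["real_score"])
print(loss)
```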
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/gan/estimator/GANHead