GANHead
tf.contrib.gan.estimator.GANHead
tf.contrib.gan.estimator.head.GANHead
Defined in tensorflow/contrib/gan/python/estimator/python/head_impl.py.
Head for a GAN.
logits_dimension
Size of the last dimension of the logits Tensor. Typically, logits is of shape [batch_size, logits_dimension].
The expected size of the logits tensor.
name
The name of this head.
A string.
__init__
__init__(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=None, name=None)
Head for GAN training.
generator_loss_fn: A TFGAN loss function for the generator. Takes a GANModel and returns a scalar.
discriminator_loss_fn: Same as generator_loss_fn, but for the discriminator.
generator_optimizer: The optimizer for generator updates.
discriminator_optimizer: Same as generator_optimizer, but for the discriminator updates.
use_loss_summaries: If True, add loss summaries. If False, do not add them. If None, use defaults.
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list of hooks. Defaults to train.get_sequential_train_hooks().
name: Name of the head. If provided, summary and metrics keys will be suffixed by "/" + name.
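The constructor arguments above can be wired together as in the following minimal sketch. It assumes TensorFlow 1.x, where tf.contrib is available. The Wasserstein losses, the Adam optimizers and their settings, the helper name build_gan_head, and the head name "gan" are illustrative assumptions, not requirements of GANHead.

```python
def build_gan_head(name="gan"):
    """Builds a GANHead with illustrative losses and optimizers.

    `build_gan_head` is a hypothetical helper, not part of TF-GAN.
    """
    # Imported lazily because tf.contrib only exists in TensorFlow 1.x.
    import tensorflow as tf

    tfgan = tf.contrib.gan

    return tfgan.estimator.GANHead(
        # Each loss function takes a GANModel and returns a scalar.
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        # Optimizer choice and learning rates are assumptions of this sketch.
        generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
        use_loss_summaries=True,
        get_hooks_fn=None,  # Fall back to the documented default hooks.
        name=name,
    )
```

With name="gan", summary and metrics keys would carry the suffix "/gan".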
create_estimator_spec
create_estimator_spec(features, mode, logits, labels=None, train_op_fn=tf.contrib.gan.gan_train_ops)
Returns EstimatorSpec that a model_fn can return.
See Head for more details.
features: Must be None.
mode: Estimator's ModeKeys.
logits: A GANModel tuple.
labels: Must be None.
train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, and discriminator optimizer, and returns a GANTrainOps tuple. For example, this function can come from TFGAN's train.py library, or can be custom.
Returns: EstimatorSpec.
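A custom train_op_fn only needs to accept the four arguments listed above and return a GANTrainOps tuple. The sketch below shows one possible shape, assuming TensorFlow 1.x: it delegates to tf.contrib.gan.gan_train_ops (the default) and only changes check_for_unused_update_ops. The function name custom_train_op_fn is hypothetical.

```python
def custom_train_op_fn(gan_model, gan_loss, generator_optimizer,
                       discriminator_optimizer):
    """Hypothetical train_op_fn with the argument order described above."""
    # Imported lazily because tf.contrib only exists in TensorFlow 1.x.
    import tensorflow as tf

    # Delegate to TF-GAN's own train-op builder, which returns a GANTrainOps
    # tuple. Disabling the unused-update-op check is an illustrative choice.
    return tf.contrib.gan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer,
        discriminator_optimizer,
        check_for_unused_update_ops=False,
    )
```

Such a function would be supplied through the train_op_fn argument of create_estimator_spec.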
Raises:
ValueError: If features isn't None.
ValueError: If train_op_fn isn't provided in train mode.
create_loss
create_loss(features, mode, logits, labels)
Returns a GANLoss tuple from the provided GANModel.
See Head for more details.
features: Input dict of Tensor objects. Unused.
mode: Estimator's ModeKeys.
logits: A GANModel tuple.
labels: Must be None.
Returns: A GANLoss tuple.
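Conceptually, create_loss turns one GANModel into a pair of losses, one per network. The pure-Python sketch below illustrates that contract with stand-in types and is not the library's implementation. GANLossSketch, create_loss_sketch, and the toy loss functions are hypothetical names, and the real method operates on TF-GAN tuples and Tensors rather than Python numbers.

```python
import collections

# Stand-in for TF-GAN's GANLoss tuple: one loss per network.
GANLossSketch = collections.namedtuple(
    "GANLossSketch", ["generator_loss", "discriminator_loss"])


def create_loss_sketch(gan_model, generator_loss_fn, discriminator_loss_fn,
                       labels=None):
    """Applies both loss functions to the same model and packs the results."""
    # Raising here is an assumption of this sketch; the documentation only
    # states that labels must be None.
    if labels is not None:
        raise ValueError("labels must be None")
    return GANLossSketch(
        generator_loss=generator_loss_fn(gan_model),
        discriminator_loss=discriminator_loss_fn(gan_model))


# Toy "model": a dict of discriminator scores instead of a real GANModel.
toy_model = {"real_score": 1.0, "fake_score": 0.25}


def toy_generator_loss(model):
    # Wasserstein-style generator loss, written out only for illustration.
    return -model["fake_score"]


def toy_discriminator_loss(model):
    # Wasserstein-style discriminator loss, written out only for illustration.
    return model["fake_score"] - model["real_score"]


loss = create_loss_sketch(toy_model, toy_generator_loss, toy_discriminator_loss)
```

Both loss functions receive the same model object, which matches the description of generator_loss_fn and discriminator_loss_fn as functions of a GANModel.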
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/gan/estimator/GANHead