# tf.GradientTape

Aliases:

- tf.GradientTape
- tf.contrib.eager.GradientTape

Defined in tensorflow/python/eager/backprop.py.
Record operations for automatic differentiation.
Operations are recorded if they are executed within this context manager and at least one of their inputs is being "watched".
Trainable variables (created by tf.contrib.eager.Variable or tf.get_variable, where trainable=True is default in both cases) are automatically watched. Tensors can be manually watched by invoking the watch method on this context manager.
For example, consider the function y = x * x. The gradient at x = 3.0 can be computed as:

``` python
x = tf.constant(3.0)
with tfe.GradientTape() as g:
  g.watch(x)
  y = x * x
grad = g.gradient(y, [x])[0]  # Will compute to 6.0
```
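Trainable variables, by contrast, need no explicit watch call. The following is a minimal sketch, assuming a TF 1.x setup where eager execution has been enabled and tfe refers to tf.contrib.eager:

``` python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()      # assumed: TF 1.x with eager execution enabled

v = tfe.Variable(3.0)            # trainable=True by default, so it is watched automatically
with tfe.GradientTape() as g:
  y = v * v
grad = g.gradient(y, [v])[0]     # Also computes to 6.0, without calling g.watch(v)
```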
GradientTapes can be nested to compute higher-order derivatives. For example,
``` python
x = tf.constant(3.0)
with tfe.GradientTape() as g:
  with tfe.GradientTape() as gg:
    gg.watch(x)
    y = x * x
  dy_dx = gg.gradient(y, [x])[0]      # Will compute to 6.0
d2y_dx2 = g.gradient(dy_dx, [x])[0]   # Will compute to 2.0
```
By default, the resources held by a GradientTape are released as soon as the GradientTape.gradient() method is called. To compute multiple gradients over the same computation, create a persistent gradient tape. This allows multiple calls to the gradient() method; resources are released only when the tape object is garbage collected. For example:
``` python
x = tf.constant(3.0)
with tfe.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y
dz_dx = g.gradient(z, [x])[0]  # 108.0 (4*x^3 at x = 3)
dy_dx = g.gradient(y, [x])[0]  # 6.0
del g  # Drop the reference to the tape
```

## Methods

<h3 id="__init__"><code>__init__</code></h3>

``` python
__init__(persistent=False)
```
Creates a new GradientTape.
Args:

- persistent: Boolean controlling whether a persistent gradient tape is created. False by default, which means at most one call can be made to the gradient() method on this object.
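To illustrate the non-persistent default, here is a minimal sketch assuming the same tfe setup as the examples above; the commented-out second call would fail:

``` python
x = tf.constant(3.0)
with tfe.GradientTape() as g:      # persistent=False is the default
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, [x])[0]      # 6.0; the tape's resources are released here
# g.gradient(y, [x])               # a second call on a non-persistent tape raises RuntimeError
```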
<h3 id="__enter__"><code>__enter__</code></h3>

``` python
__enter__()
```
<h3 id="__exit__"><code>__exit__</code></h3>

``` python
__exit__(
    typ,
    value,
    traceback
)
```
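For reference, the with blocks in the examples above drive these two methods via Python's context-manager protocol. A minimal sketch of the equivalent manual usage (the with statement is preferred in practice):

``` python
x = tf.constant(3.0)
g = tfe.GradientTape()
g.__enter__()                      # begin recording, as the with statement would
try:
  g.watch(x)
  y = x * x
finally:
  g.__exit__(None, None, None)     # stop recording, even if an error occurred
grad = g.gradient(y, [x])[0]       # 6.0
```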
<h3 id="gradient"><code>gradient</code></h3>

``` python
gradient(
    target,
    sources,
    output_gradients=None
)
```
Computes the gradient using operations recorded in context of this tape.
Args:

- target: Tensor to be differentiated.
- sources: a list or nested structure of Tensors or Variables. target will be differentiated against elements in sources.
- output_gradients: a list of gradients, one for each element of target. Defaults to None.

Returns:

a list or nested structure of Tensors (or IndexedSlices, or None), one for each element in sources. Returned structure is the same as the structure of sources.
Raises:

- RuntimeError: if called inside the context of the tape, or if called more than once on a non-persistent tape.
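The output_gradients argument seeds the backward pass with upstream gradients, i.e. a vector-Jacobian product. A minimal sketch, assuming the same tfe setup as the examples above:

``` python
x = tf.constant([1.0, 2.0])
with tfe.GradientTape() as g:
  g.watch(x)
  y = x * x                        # dy/dx = 2 * x = [2.0, 4.0]
weights = tf.constant([1.0, 0.5])  # upstream gradient, one entry per element of y
grad = g.gradient(y, [x], output_gradients=weights)[0]
# grad == weights * 2 * x == [2.0, 2.0]
```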
<h3 id="watch"><code>watch</code></h3>

``` python
watch(tensor)
```
Ensures that tensor is being traced by this tape.
Args:

- tensor: a Tensor or list of Tensors.
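Since watch accepts a list of Tensors (per the Args above), several tensors can be watched in one call. A minimal sketch, assuming the same tfe setup as earlier:

``` python
x1 = tf.constant(2.0)
x2 = tf.constant(5.0)
with tfe.GradientTape() as g:
  g.watch([x1, x2])                 # watch a list of Tensors at once
  y = x1 * x2
dx1, dx2 = g.gradient(y, [x1, x2])  # 5.0 and 2.0
```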
<h3 id="watched_variables"><code>watched_variables</code></h3>

``` python
watched_variables()
```
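The page gives no description for this method; as an illustration only, a sketch assuming it returns the variables the tape has watched, as it does in later TensorFlow releases:

``` python
v = tfe.Variable(2.0, name="v")
with tfe.GradientTape() as g:
  y = v * v                        # v is trainable, so it is watched automatically
print(g.watched_variables())       # expected to contain the variable v
```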
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/GradientTape