Aliases: tf.contrib.eager.custom_gradient, tf.custom_gradient
tf.custom_gradient(f)
Defined in tensorflow/python/ops/custom_gradient.py.
Decorator to define a function with a custom gradient.
This decorator allows fine-grained control over the gradients of a sequence of operations. This may be useful for multiple reasons, including providing a more efficient or numerically stable gradient for a sequence of operations.
For example, consider the following function that commonly occurs in the computation of cross entropy and log likelihoods:
```python
def log1pexp(x):
  return tf.log(1 + tf.exp(x))
```
Due to numerical instability, the gradient of this function evaluated at x=100 is NaN. For example:
```python
x = tf.constant(100.)
y = log1pexp(x)
dy = tf.gradients(y, x)  # Will be NaN when evaluated.
```
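The NaN comes from the chain rule applied to the naive expression: the derivative of log(1 + e^x) is e^x / (1 + e^x), and once e^x overflows to infinity this becomes inf / inf. The plain-Python sketch below reproduces that failure without TensorFlow. It assumes Python's float64 floats, which overflow near x ≈ 709.8 rather than near x ≈ 88.7 for TensorFlow's default float32, so it evaluates at x = 1000 instead of x = 100. `exp_or_inf` and `naive_log1pexp_grad` are illustrative helpers, not TensorFlow functions.

```python
import math

def exp_or_inf(x):
    """exp(x) that returns inf on overflow, the way floating-point tensor ops do.

    Python's math.exp raises OverflowError instead, so the error is caught here.
    """
    try:
        return math.exp(x)
    except OverflowError:
        return math.inf

def naive_log1pexp_grad(x):
    # Derivative of log(1 + e^x) written directly as e^x / (1 + e^x).
    e = exp_or_inf(x)
    return e / (1 + e)

print(naive_log1pexp_grad(0.0))     # 0.5
print(naive_log1pexp_grad(1000.0))  # nan, because inf / inf is nan
```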
The gradient expression can be analytically simplified to provide numerical stability:
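```python
@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    # Equivalent to dy * e / (1 + e), but never forms inf / inf when e overflows.
    return dy * (1 - 1 / (1 + e))
  # Return the forward value together with the function that computes its gradient.
  return tf.log(1 + e), grad
```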
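The expression returned by grad is an exact algebraic rewrite of the derivative, not an approximation:

```latex
\frac{d}{dx}\log\left(1 + e^{x}\right)
  = \frac{e^{x}}{1 + e^{x}}
  = \frac{\left(1 + e^{x}\right) - 1}{1 + e^{x}}
  = 1 - \frac{1}{1 + e^{x}}
```

When e^x overflows to infinity, 1 / (1 + e^x) evaluates to 0, so the right-hand form yields 1, the correct limit, where the left-hand quotient yields NaN. Multiplying by dy applies the chain rule to whatever gradient flows in from downstream.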
With this definition, the gradient at x=100 will be correctly evaluated as 1.0.
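The effect of the rewrite can be checked in plain Python, without TensorFlow. In this sketch `exp_or_inf`, `naive_log1pexp_grad`, and `stable_log1pexp_grad` are illustrative helpers, not TensorFlow APIs; `stable_log1pexp_grad` mirrors the body of grad with dy defaulting to 1. Because Python floats are float64, the overflow case uses x = 1000 in place of x = 100.

```python
import math

def exp_or_inf(x):
    # Return inf on overflow instead of raising, like floating-point tensor ops.
    try:
        return math.exp(x)
    except OverflowError:
        return math.inf

def naive_log1pexp_grad(x):
    e = exp_or_inf(x)
    return e / (1 + e)

def stable_log1pexp_grad(x, dy=1.0):
    # Same expression as grad(dy) in the decorated log1pexp: dy * (1 - 1 / (1 + e)).
    e = exp_or_inf(x)
    return dy * (1 - 1 / (1 + e))

# Overflow case: 1 / (1 + inf) == 0.0, so the gradient is exactly 1.0.
print(stable_log1pexp_grad(1000.0))  # 1.0

# Moderate inputs: both forms agree up to floating-point rounding.
for x in (-5.0, 0.0, 3.0):
    assert math.isclose(stable_log1pexp_grad(x), naive_log1pexp_grad(x))
```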
See also tf.RegisterGradient, which registers a gradient function for a primitive TensorFlow operation. tf.custom_gradient, on the other hand, allows for fine-grained control over the gradient computation of a sequence of operations.
Args:

f: function f(x) that returns a tuple (y, grad_fn) where:
- x is a Tensor or sequence of Tensor inputs to the function.
- y is a Tensor or sequence of Tensor outputs of applying TensorFlow operations in f to x.
- grad_fn is a function with the signature g(grad_ys) which returns a list of Tensors: the derivatives of the Tensors in y with respect to the Tensors in x. grad_ys is a Tensor or sequence of Tensors the same size as y holding the initial value gradients for each Tensor in y.

Returns:

A function h(x) which returns the same value as f(x)[0] and whose gradient (as calculated by tf.gradients) is determined by f(x)[1].
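To make the f / grad_fn / h contract concrete, here is a toy, framework-free sketch. It is not how TensorFlow implements tf.custom_gradient: `toy_custom_gradient`, `toy_gradient`, and the `forward_with_grad` attribute are hypothetical names, inputs are plain Python floats rather than Tensors, and only first-order gradients of a scalar output are supported. As described above, the decorated function h returns f(x)[0], and its gradient comes from grad_fn rather than from the operations inside f.

```python
import math

def toy_custom_gradient(f):
    """Hypothetical stand-in for the decorator: wraps f(*xs) -> (y, grad_fn)."""
    def h(*xs):
        y, _ = f(*xs)
        return y  # h(x) returns the same value as f(x)[0]
    h.forward_with_grad = f  # toy hook so toy_gradient can reach grad_fn
    return h

def toy_gradient(h, *xs):
    """Hypothetical analogue of tf.gradients for a scalar-output h.

    Seeds grad_ys with 1.0 and returns one derivative per input, taken from grad_fn.
    """
    _, grad_fn = h.forward_with_grad(*xs)
    grads = grad_fn(1.0)
    # A single-input grad_fn may return a bare value instead of a list.
    return list(grads) if isinstance(grads, (list, tuple)) else [grads]

@toy_custom_gradient
def log1pexp(x):
    e = math.exp(x)
    def grad(dy):
        return dy * (1 - 1 / (1 + e))
    return math.log(1 + e), grad

@toy_custom_gradient
def product(a, b):
    # Two inputs: grad_fn returns a list with one derivative per input.
    def grad(dy):
        return [dy * b, dy * a]
    return a * b, grad

print(log1pexp(0.0))                    # log(2), about 0.6931
print(toy_gradient(log1pexp, 0.0))      # [0.5]
print(toy_gradient(product, 3.0, 4.0))  # [4.0, 3.0]
```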
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/custom_gradient