
Decorator to define a function with a custom gradient.

```python
tf.custom_gradient(
    f=None
)
```

This decorator allows fine-grained control over the gradients of a sequence of operations. This may be useful for multiple reasons, including providing a more efficient or numerically stable gradient for a sequence of operations.

For example, consider the following function that commonly occurs in the computation of cross entropy and log likelihoods:

```python
def log1pexp(x):
  return tf.math.log(1 + tf.exp(x))
```

Due to numerical instability (`tf.exp(100.)` overflows to `inf` in float32), the gradient of this function evaluated at `x=100` is `NaN`. For example:

```python
x = tf.constant(100.)
y = log1pexp(x)
dy = tf.gradients(y, x)  # Will be NaN when evaluated.
```

The gradient expression can be analytically simplified to provide numerical stability:

```python
@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad
```

With this definition, the gradient at x=100 will be correctly evaluated as 1.0.
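This can be checked directly with `tf.GradientTape`, TF2's eager-mode gradient API; a minimal sketch, restating the decorated function above:

```python
import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

x = tf.constant(100.)
with tf.GradientTape() as tape:
  tape.watch(x)  # x is a constant Tensor, so it must be watched explicitly
  y = log1pexp(x)
print(tape.gradient(y, x))  # tf.Tensor(1.0, shape=(), dtype=float32)
```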

Nesting custom gradients can lead to unintuitive results. The default behavior does not correspond to n-th order derivatives. For example:

```python
@tf.custom_gradient
def op(x):
  y = op1(x)
  @tf.custom_gradient
  def grad_fn(dy):
    gdy = op2(x, y, dy)
    def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
      return op3(x, y, dy, ddy)
    return gdy, grad_grad_fn
  return y, grad_fn
```

The function `grad_grad_fn` computes the first order gradient of `grad_fn` with respect to `dy`, which is used to generate forward-mode gradient graphs from backward-mode gradient graphs, but is not the same as the second order gradient of `op` with respect to `x`.

Instead, wrap nested `@tf.custom_gradient` functions in another function:

```python
@tf.custom_gradient
def op_with_fused_backprop(x):
  y, x_grad = fused_op(x)
  def first_order_gradient(dy):
    @tf.custom_gradient
    def first_order_custom(unused_x):
      def second_order_and_transpose(ddy):
        return second_order_for_x(...), gradient_wrt_dy(...)
      return x_grad, second_order_and_transpose
    return dy * first_order_custom(x)
  return y, first_order_gradient
```
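To make the pattern concrete, here is a self-contained sketch for `f(x) = x**3`, where the "fused" op computes the value and its first derivative together. The names (`cube_with_fused_backprop` and friends) are illustrative, not part of the API:

```python
import tensorflow as tf

@tf.custom_gradient
def cube_with_fused_backprop(x):
  # "Fused" forward pass: value and first derivative computed together.
  y, x_grad = x ** 3, 3. * x ** 2

  def first_order_gradient(dy):
    @tf.custom_gradient
    def first_order_custom(unused_x):
      def second_order_gradient(ddy):
        # Derivative of x_grad (= 3x**2) with respect to x, scaled by ddy.
        return ddy * 6. * x
      return x_grad, second_order_gradient
    return dy * first_order_custom(x)

  return y, first_order_gradient

x = tf.constant(3.0)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = cube_with_fused_backprop(x)
  dy_dx = inner.gradient(y, x)      # 3 * 3**2 = 27.0
d2y_dx2 = outer.gradient(dy_dx, x)  # 6 * 3 = 18.0
```

With this wrapping, nested `tf.GradientTape`s recover the true second-order gradient through the inner custom gradient.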

Additional arguments to the inner `@tf.custom_gradient`-decorated function control the expected return values of the innermost function.

See also `tf.RegisterGradient`, which registers a gradient function for a primitive TensorFlow operation. `tf.custom_gradient`, on the other hand, allows for fine-grained control over the gradient computation of a sequence of operations.

Note that if the decorated function uses `Variable`s, the enclosing variable scope must be using `ResourceVariable`s.
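A minimal sketch of what this looks like in practice, assuming a TF 2.x `ResourceVariable` (the function and names below are illustrative; the `variables=None` keyword in `grad_fn`'s signature is the documented mechanism described under Args):

```python
import tensorflow as tf

v = tf.Variable(2.0)  # a ResourceVariable, the default in TF 2.x

@tf.custom_gradient
def scale_by_v(x):
  y = v * x
  def grad(dy, variables=None):
    # Return one gradient for the input, plus a list with one gradient
    # per variable in `variables` (here just `v`).
    return dy * v, [dy * x]
  return y, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = scale_by_v(x)
dx, dv = tape.gradient(y, [x, v])  # dx == 2.0, dv == 3.0
```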

Args:

- `f`: function `f(*x)` that returns a tuple `(y, grad_fn)`, where:
  - `x` is a sequence of (nested structures of) `Tensor` inputs to the function.
  - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow operations in `f` to `x`.
  - `grad_fn` is a function with the signature `g(*grad_ys)` which returns a list of `Tensor`s the same size as (flattened) `x`: the derivatives of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is a sequence of `Tensor`s the same size as (flattened) `y` holding the initial value gradients for each `Tensor` in `y`. (See the sketch after this section.)

In a pure mathematical sense, a vector-argument vector-valued function `f`'s derivatives should be its Jacobian matrix `J`. Here we are expressing the Jacobian `J` as a function `grad_fn` which defines how `J` will transform a vector `grad_ys` when left-multiplied with it (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional representation of a matrix is convenient to use for chain-rule calculation (in e.g. the back-propagation algorithm).

If `f` uses `Variable`s (that are not part of the inputs), i.e. through `get_variable`, then `grad_fn` should have signature `g(*grad_ys, variables=None)`, where `variables` is a list of the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>` with the derivatives of `Tensor`s in `y` with respect to the variables (that is, `grad_vars` has one `Tensor` per variable in `variables`).
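As an illustrative sketch of these conventions, here is a two-input function whose `grad_fn` returns one gradient per input (a vector-Jacobian product); the function `my_mul` is hypothetical:

```python
import tensorflow as tf

@tf.custom_gradient
def my_mul(a, b):
  y = a * b
  def grad(dy):            # dy corresponds to grad_ys, one per output
    return dy * b, dy * a  # one gradient per input, same order as (a, b)
  return y, grad

a, b = tf.constant(2.0), tf.constant(5.0)
with tf.GradientTape() as tape:
  tape.watch([a, b])
  y = my_mul(a, b)
print(tape.gradient(y, [a, b]))  # [5.0, 2.0]
```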

Returns:

A function `h(x)` which returns the same value as `f(x)[0]` and whose gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.

© 2020 The TensorFlow Authors. All rights reserved.

Licensed under the Creative Commons Attribution License 3.0.

Code samples licensed under the Apache 2.0 License.

https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/custom_gradient