Decorator to register a KL divergence implementation function.
tf.compat.v1.distributions.RegisterKL(dist_cls_a, dist_cls_b)
Usage:

    @distributions.RegisterKL(distributions.Normal, distributions.Normal)
    def _kl_normal_mvn(norm_a, norm_b):
      # Return KL(norm_a || norm_b)
      ...
Args | |
---|---|
dist_cls_a | The class of the first argument of the KL divergence. |
dist_cls_b | The class of the second argument of the KL divergence. |
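
The usage stub above leaves the function body as a comment. As a rough illustration of what a registered function might compute, the sketch below evaluates the closed-form KL divergence between two univariate normal distributions in plain Python. `SimpleNormal` and `kl_normal_normal` are hypothetical names used only for this example and are not part of TensorFlow; a real registered function would receive distribution objects and work with tensors.

```python
import math
from collections import namedtuple

# Hypothetical stand-in for a distribution object (not a TensorFlow class).
SimpleNormal = namedtuple("SimpleNormal", ["loc", "scale"])


def kl_normal_normal(norm_a, norm_b):
    """Closed-form KL(norm_a || norm_b) for two univariate normals."""
    var_a = norm_a.scale ** 2
    var_b = norm_b.scale ** 2
    mean_diff_sq = (norm_a.loc - norm_b.loc) ** 2
    return (math.log(norm_b.scale / norm_a.scale)
            + (var_a + mean_diff_sq) / (2.0 * var_b)
            - 0.5)


p = SimpleNormal(loc=0.0, scale=1.0)
q = SimpleNormal(loc=0.0, scale=2.0)

# KL divergence is not symmetric, so the argument order matters.
kl_pq = kl_normal_normal(p, q)
kl_qp = kl_normal_normal(q, p)
```

Because KL(a || b) generally differs from KL(b || a), the order of `dist_cls_a` and `dist_cls_b` in the registration is significant.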
__call__
__call__(kl_fn)
Perform the KL registration.
Args | |
---|---|
kl_fn | The function to use for the KL divergence. |
Returns | |
---|---|
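kl_fn | The same function that was passed in, returned unchanged so that the object can be used as a decorator. |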
Raises | |
---|---|
TypeError | If kl_fn is not callable. |
ValueError | If a KL divergence function has already been registered for the given argument classes. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/compat/v1/distributions/RegisterKL