Creates a new tf.estimator.Estimator that reports the metrics returned by metric_fn in addition to the given estimator's own metrics.
```python
tf.estimator.add_metrics(
    estimator, metric_fn
)
```
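Example:

```python
import tensorflow as tf


def my_auc(labels, predictions):
  # Build a Keras AUC metric and feed it the model's 'logistic' output.
  auc_metric = tf.keras.metrics.AUC(name="my_auc")
  auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
  # Metrics are returned as a dict keyed by metric name.
  return {'auc': auc_metric}

estimator = tf.estimator.DNNClassifier(...)
# Wrap the estimator so that evaluate() also reports 'auc'.
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
```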
Example usage of a custom metric that uses features:
```python
import tensorflow as tf


def my_auc(labels, predictions, features):
  auc_metric = tf.keras.metrics.AUC(name="my_auc")
  # Weight each example by its 'weight' feature; the input_fn must provide it.
  auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
                          sample_weight=features['weight'])
  return {'auc': auc_metric}

estimator = tf.estimator.DNNClassifier(...)
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
```
Args | |
---|---|
estimator | A tf.estimator.Estimator object. |
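metric_fn | A function which should obey the following signature. Args: it can take only the following four arguments, in any order: predictions (the predictions Tensor, or dict of Tensors, created by the given estimator), features (the input dict of Tensor objects created by the input_fn passed to estimator.evaluate), labels (the labels Tensor, or dict of Tensors, created by that same input_fn) and config (the config attribute of the estimator). Returns: a dict of metric results keyed by name. If a name clashes with one of the estimator's existing metrics, the metric_fn result overrides it. Each value is either a metric object, such as the tf.keras.metrics.AUC instance in the examples, or a (metric_tensor, update_op) tuple. |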
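As a framework-free sketch, assuming metric_fn may declare any subset of the arguments predictions, features, labels and config in any order, a wrapper can inspect its parameters, reject unknown names, and pass the rest by keyword. The names call_metric_fn and VALID_METRIC_FN_ARGS are hypothetical and not part of the TensorFlow API; the sketch only shows why argument order does not matter.

```python
import inspect

# Assumed set of argument names a metric_fn is allowed to declare.
VALID_METRIC_FN_ARGS = {"features", "labels", "predictions", "config"}


def call_metric_fn(metric_fn, features, labels, predictions, config):
    """Hypothetical helper: call metric_fn with only the arguments it declares."""
    params = set(inspect.signature(metric_fn).parameters)
    unexpected = params - VALID_METRIC_FN_ARGS
    if unexpected:
        raise ValueError(f"metric_fn has unexpected args: {sorted(unexpected)}")
    available = {
        "features": features,
        "labels": labels,
        "predictions": predictions,
        "config": config,
    }
    # Passing by keyword means the order of metric_fn's parameters is irrelevant.
    kwargs = {name: available[name] for name in params}
    return metric_fn(**kwargs)


# Usage: this metric_fn declares only predictions and labels.
def my_metric(predictions, labels):
    return {"n": len(labels)}


print(call_metric_fn(my_metric, features={}, labels=[1, 0], predictions={}, config=None))
# {'n': 2}
```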
Returns | |
---|---|
A new tf.estimator.Estimator whose metrics are the union of the original estimator's metrics and those returned by metric_fn. |
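To illustrate the "union of original metrics with given ones", here is a minimal framework-free sketch of wrapping a model function so that its evaluation metrics are merged with those from metric_fn. Spec and wrap_model_fn are hypothetical stand-ins, not TensorFlow classes. The sketch assumes the extra metrics are only added in evaluation mode (represented here by the string "eval") and, for brevity, passes metric_fn only labels and predictions.

```python
from typing import Callable, Dict, NamedTuple, Optional


class Spec(NamedTuple):
    """Hypothetical stand-in for the spec a model function returns."""
    predictions: dict
    eval_metric_ops: Optional[Dict[str, object]] = None


def wrap_model_fn(model_fn: Callable, metric_fn: Callable) -> Callable:
    """Hypothetical sketch of merging metric_fn results into a model's metrics."""

    def new_model_fn(features, labels, mode, config=None):
        spec = model_fn(features, labels, mode, config)
        if mode != "eval":
            # Assumption: training and prediction are left unchanged.
            return spec
        new_metrics = metric_fn(labels=labels, predictions=spec.predictions)
        merged = dict(spec.eval_metric_ops or {})
        # dict.update: on a name clash, the metric_fn value overrides the original.
        merged.update(new_metrics)
        return spec._replace(eval_metric_ops=merged)

    return new_model_fn


# Usage with toy values in place of real metrics.
def base_model_fn(features, labels, mode, config):
    return Spec(predictions={"logistic": [0.9]}, eval_metric_ops={"loss": 0.1, "auc": 0.5})


def my_metrics(labels, predictions):
    return {"auc": 0.8, "count": len(labels)}


model_fn = wrap_model_fn(base_model_fn, my_metrics)
print(model_fn({}, [1], "eval").eval_metric_ops)
# {'loss': 0.1, 'auc': 0.8, 'count': 1}
```

Copying the original dict before updating keeps the wrapped model's own spec untouched, so calling the wrapper never mutates state owned by the original model function.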
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/estimator/add_metrics