View source on GitHub |
Computes the confusion matrix from predictions and labels.
tf.math.confusion_matrix( labels, predictions, num_classes=None, dtype=tf.dtypes.int32, name=None, weights=None )
The matrix columns represent the prediction labels and the rows represent the real labels. The confusion matrix is always a 2-D array of shape [n, n]
, where n
is the number of valid labels for a given classification task. Both prediction and labels must be 1-D arrays of the same shape in order for this function to work.
If num_classes
is None
, then num_classes
will be set to one plus the maximum value in either predictions or labels. Class labels are expected to start at 0. For example, if num_classes
is 3, then the possible labels would be [0, 1, 2]
.
If weights
is not None
, then each prediction contributes its corresponding weight to the total value of the confusion matrix cell.
tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==> [[0 0 0 0 0] [0 0 1 0 0] [0 0 1 0 0] [0 0 0 0 0] [0 0 0 0 1]]
Note that the possible labels are assumed to be [0, 1, 2, 3, 4]
, resulting in a 5x5 confusion matrix.
Args | |
---|---|
labels | 1-D Tensor of real labels for the classification task. |
predictions | 1-D Tensor of predictions for a given classification. |
num_classes | The possible number of labels the classification task can have. If this value is not provided, it will be calculated using both predictions and labels array. |
dtype | Data type of the confusion matrix. |
name | Scope name. |
weights | An optional Tensor whose shape matches predictions . |
Returns | |
---|---|
A Tensor of type dtype with shape [n, n] representing the confusion matrix, where n is the number of possible labels in the classification task. |
Raises | |
---|---|
ValueError | If both predictions and labels are not 1-D vectors and have mismatched shapes, or if weights is not None and its shape doesn't match predictions . |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/confusion_matrix