W3cubDocs

/TensorFlow 2.4

tf.compat.v1.confusion_matrix

Computes the confusion matrix from predictions and labels.

The matrix columns represent the prediction labels and the rows represent the real labels. The confusion matrix is always a 2-D array of shape [n, n], where n is the number of valid labels for a given classification task. Both prediction and labels must be 1-D arrays of the same shape in order for this function to work.

If num_classes is None, then num_classes will be set to one plus the maximum value in either predictions or labels. Class labels are expected to start at 0. For example, if num_classes is 3, then the possible labels would be [0, 1, 2].

If weights is not None, then each prediction contributes its corresponding weight to the total value of the confusion matrix cell.

For example:

tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
    [[0 0 0 0 0]
     [0 0 1 0 0]
     [0 0 1 0 0]
     [0 0 0 0 0]
     [0 0 0 0 1]]

Note that the possible labels are assumed to be [0, 1, 2, 3, 4], resulting in a 5x5 confusion matrix.

Args
labels 1-D Tensor of real labels for the classification task.
predictions 1-D Tensor of predictions for a given classification.
num_classes The possible number of labels the classification task can have. If this value is not provided, it will be calculated using both predictions and labels array.
dtype Data type of the confusion matrix.
name Scope name.
weights An optional Tensor whose shape matches predictions.
Returns
A Tensor of type dtype with shape [n, n] representing the confusion matrix, where n is the number of possible labels in the classification task.
Raises
ValueError If both predictions and labels are not 1-D vectors and have mismatched shapes, or if weights is not None and its shape doesn't match predictions.

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/compat/v1/confusion_matrix