from itertools import product

import numpy as np

from .. import confusion_matrix
from ...utils import check_matplotlib_support
from ...utils.multiclass import unique_labels
from ...utils.validation import _deprecate_positional_args
from ...base import is_classifier


class ConfusionMatrixDisplay:
    """Confusion Matrix visualization.

    It is recommended to use :func:`~sklearn.metrics.plot_confusion_matrix`
    to create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
    attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    confusion_matrix : ndarray of shape (n_classes, n_classes)
        Confusion matrix.

    display_labels : ndarray of shape (n_classes,), default=None
        Display labels for plot. If None, display labels are set from 0 to
        `n_classes - 1`.

    Attributes
    ----------
    im_ : matplotlib AxesImage
        Image representing the confusion matrix.

    text_ : ndarray of shape (n_classes, n_classes), dtype=matplotlib Text, \
            or None
        Array of matplotlib text elements, one per matrix cell. `None` if
        `include_values` is false.

    ax_ : matplotlib Axes
        Axes with confusion matrix.

    figure_ : matplotlib Figure
        Figure containing the confusion matrix.
    """
    def __init__(self, confusion_matrix, *, display_labels=None):
        self.confusion_matrix = confusion_matrix
        self.display_labels = display_labels

    @_deprecate_positional_args
    def plot(self, *, include_values=True, cmap='viridis',
             xticks_rotation='horizontal', values_format=None, ax=None):
        """Plot visualization.

        Parameters
        ----------
        include_values : bool, default=True
            Includes values in confusion matrix.

        cmap : str or matplotlib Colormap, default='viridis'
            Colormap recognized by matplotlib.

        xticks_rotation : {'vertical', 'horizontal'} or float, \
                default='horizontal'
            Rotation of xtick labels.

        values_format : str, default=None
            Format specification for values in confusion matrix. If `None`,
            float matrices use '.2g'; other matrices use 'd' or '.2g',
            whichever gives the shorter string for each cell.

        ax : matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        Returns
        -------
        display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
            This instance, with `im_`, `text_`, `ax_` and `figure_` set.
        """
        check_matplotlib_support("ConfusionMatrixDisplay.plot")
        import matplotlib.pyplot as plt

        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure

        cm = self.confusion_matrix
        n_classes = cm.shape[0]
        self.im_ = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        self.text_ = None

        # First and last in-range colors of the colormap. Querying the top
        # end with the float 1.0 (rather than the integer 256) selects the
        # last entry for any colormap size and never returns the "over"
        # color.
        cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)

        if include_values:
            self.text_ = np.empty_like(cm, dtype=object)

            # Choose the text color from the opposite end of the colormap to
            # the cell's value, so the text contrasts with the background.
            # This presumably assumes a colormap that is monotone in
            # lightness (e.g. 'viridis').
            thresh = (cm.max() + cm.min()) / 2.0

            for i, j in product(range(n_classes), range(n_classes)):
                color = cmap_max if cm[i, j] < thresh else cmap_min

                if values_format is None:
                    text_cm = format(cm[i, j], '.2g')
                    # Integer-like counts: prefer plain 'd' formatting when
                    # it is shorter than '.2g' (which would give e.g. '1.2e+02').
                    if cm.dtype.kind != 'f':
                        text_d = format(cm[i, j], 'd')
                        if len(text_d) < len(text_cm):
                            text_cm = text_d
                else:
                    text_cm = format(cm[i, j], values_format)

                # imshow puts column j on the x axis and row i on the y axis.
                self.text_[i, j] = ax.text(
                    j, i, text_cm,
                    ha="center", va="center",
                    color=color)

        if self.display_labels is None:
            display_labels = np.arange(n_classes)
        else:
            display_labels = self.display_labels

        fig.colorbar(self.im_, ax=ax)
        ax.set(xticks=np.arange(n_classes),
               yticks=np.arange(n_classes),
               xticklabels=display_labels,
               yticklabels=display_labels,
               ylabel="True label",
               xlabel="Predicted label")

        # Explicit limits keep row 0 at the top and avoid clipping the
        # first and last rows of cells.
        ax.set_ylim((n_classes - 0.5, -0.5))
        plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)

        self.figure_ = fig
        self.ax_ = ax
        return self


@_deprecate_positional_args
def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
                          sample_weight=None, normalize=None,
                          display_labels=None, include_values=True,
                          xticks_rotation='horizontal',
                          values_format=None,
                          cmap='viridis', ax=None):
    """Plot Confusion Matrix.

    Read more in the :ref:`User Guide <confusion_matrix>`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y_true : array-like of shape (n_samples,)
        Target values.

    labels : array-like of shape (n_classes,), default=None
        List of labels to index the matrix. This may be used to reorder or
        select a subset of labels. If `None` is given, those that appear at
        least once in `y_true` or `y_pred` are used in sorted order.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    normalize : {'true', 'pred', 'all'}, default=None
        Normalizes confusion matrix over the true (rows), predicted (columns)
        conditions or all the population. If None, confusion matrix will not
        be normalized.

    display_labels : array-like of shape (n_classes,), default=None
        Target names used for plotting. By default, `labels` will be used if
        it is defined, otherwise the unique labels of `y_true` and `y_pred`
        will be used.

    include_values : bool, default=True
        Includes values in confusion matrix.

    xticks_rotation : {'vertical', 'horizontal'} or float, \
            default='horizontal'
        Rotation of xtick labels.

    values_format : str, default=None
        Format specification for values in confusion matrix. If `None`,
        float matrices use '.2g'; other matrices use 'd' or '.2g',
        whichever gives the shorter string for each cell.

    cmap : str or matplotlib Colormap, default='viridis'
        Colormap recognized by matplotlib.

    ax : matplotlib Axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.

    Returns
    -------
    display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`

    Raises
    ------
    ValueError
        If `estimator` is not a classifier.

    Examples
    --------
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import plot_confusion_matrix
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...         X, y, random_state=0)
    >>> clf = SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> plot_confusion_matrix(clf, X_test, y_test)  # doctest: +SKIP
    >>> plt.show()  # doctest: +SKIP
    """
    check_matplotlib_support("plot_confusion_matrix")

    if not is_classifier(estimator):
        raise ValueError("plot_confusion_matrix only supports classifiers")

    y_pred = estimator.predict(X)
    cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
                          labels=labels, normalize=normalize)

    if display_labels is None:
        if labels is None:
            # Use the same label set the matrix was built from (the sorted
            # labels present in y_true or y_pred), not `estimator.classes_`:
            # when the evaluated data lacks some classes the matrix is
            # smaller than `classes_` and the tick labels would be wrong.
            display_labels = unique_labels(y_true, y_pred)
        else:
            display_labels = labels

    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=display_labels)
    return disp.plot(include_values=include_values,
                     cmap=cmap, ax=ax, xticks_rotation=xticks_rotation,
                     values_format=values_format)