Uploaded Test files
This commit is contained in:
parent
f584ad9d97
commit
2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions
0
venv/Lib/site-packages/sklearn/metrics/_plot/__init__.py
Normal file
0
venv/Lib/site-packages/sklearn/metrics/_plot/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
40
venv/Lib/site-packages/sklearn/metrics/_plot/base.py
Normal file
40
venv/Lib/site-packages/sklearn/metrics/_plot/base.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
def _check_classifer_response_method(estimator, response_method):
|
||||
"""Return prediction method from the response_method
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator: object
|
||||
Classifier to check
|
||||
|
||||
response_method: {'auto', 'predict_proba', 'decision_function'}
|
||||
Specifies whether to use :term:`predict_proba` or
|
||||
:term:`decision_function` as the target response. If set to 'auto',
|
||||
:term:`predict_proba` is tried first and if it does not exist
|
||||
:term:`decision_function` is tried next.
|
||||
|
||||
Returns
|
||||
-------
|
||||
prediction_method: callable
|
||||
prediction method of estimator
|
||||
"""
|
||||
|
||||
if response_method not in ("predict_proba", "decision_function", "auto"):
|
||||
raise ValueError("response_method must be 'predict_proba', "
|
||||
"'decision_function' or 'auto'")
|
||||
|
||||
error_msg = "response method {} is not defined in {}"
|
||||
if response_method != "auto":
|
||||
prediction_method = getattr(estimator, response_method, None)
|
||||
if prediction_method is None:
|
||||
raise ValueError(error_msg.format(response_method,
|
||||
estimator.__class__.__name__))
|
||||
else:
|
||||
predict_proba = getattr(estimator, 'predict_proba', None)
|
||||
decision_function = getattr(estimator, 'decision_function', None)
|
||||
prediction_method = predict_proba or decision_function
|
||||
if prediction_method is None:
|
||||
raise ValueError(error_msg.format(
|
||||
"decision_function or predict_proba",
|
||||
estimator.__class__.__name__))
|
||||
|
||||
return prediction_method
|
233
venv/Lib/site-packages/sklearn/metrics/_plot/confusion_matrix.py
Normal file
233
venv/Lib/site-packages/sklearn/metrics/_plot/confusion_matrix.py
Normal file
|
@ -0,0 +1,233 @@
|
|||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import confusion_matrix
|
||||
from ...utils import check_matplotlib_support
|
||||
from ...utils.validation import _deprecate_positional_args
|
||||
from ...base import is_classifier
|
||||
|
||||
|
||||
class ConfusionMatrixDisplay:
|
||||
"""Confusion Matrix visualization.
|
||||
|
||||
It is recommend to use :func:`~sklearn.metrics.plot_confusion_matrix` to
|
||||
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
|
||||
attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
confusion_matrix : ndarray of shape (n_classes, n_classes)
|
||||
Confusion matrix.
|
||||
|
||||
display_labels : ndarray of shape (n_classes,), default=None
|
||||
Display labels for plot. If None, display labels are set from 0 to
|
||||
`n_classes - 1`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
im_ : matplotlib AxesImage
|
||||
Image representing the confusion matrix.
|
||||
|
||||
text_ : ndarray of shape (n_classes, n_classes), dtype=matplotlib Text, \
|
||||
or None
|
||||
Array of matplotlib axes. `None` if `include_values` is false.
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with confusion matrix.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the confusion matrix.
|
||||
"""
|
||||
def __init__(self, confusion_matrix, *, display_labels=None):
|
||||
self.confusion_matrix = confusion_matrix
|
||||
self.display_labels = display_labels
|
||||
|
||||
@_deprecate_positional_args
|
||||
def plot(self, *, include_values=True, cmap='viridis',
|
||||
xticks_rotation='horizontal', values_format=None, ax=None):
|
||||
"""Plot visualization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
include_values : bool, default=True
|
||||
Includes values in confusion matrix.
|
||||
|
||||
cmap : str or matplotlib Colormap, default='viridis'
|
||||
Colormap recognized by matplotlib.
|
||||
|
||||
xticks_rotation : {'vertical', 'horizontal'} or float, \
|
||||
default='horizontal'
|
||||
Rotation of xtick labels.
|
||||
|
||||
values_format : str, default=None
|
||||
Format specification for values in confusion matrix. If `None`,
|
||||
the format specification is 'd' or '.2g' whichever is shorter.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
|
||||
"""
|
||||
check_matplotlib_support("ConfusionMatrixDisplay.plot")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
else:
|
||||
fig = ax.figure
|
||||
|
||||
cm = self.confusion_matrix
|
||||
n_classes = cm.shape[0]
|
||||
self.im_ = ax.imshow(cm, interpolation='nearest', cmap=cmap)
|
||||
self.text_ = None
|
||||
cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(256)
|
||||
|
||||
if include_values:
|
||||
self.text_ = np.empty_like(cm, dtype=object)
|
||||
|
||||
# print text with appropriate color depending on background
|
||||
thresh = (cm.max() + cm.min()) / 2.0
|
||||
|
||||
for i, j in product(range(n_classes), range(n_classes)):
|
||||
color = cmap_max if cm[i, j] < thresh else cmap_min
|
||||
|
||||
if values_format is None:
|
||||
text_cm = format(cm[i, j], '.2g')
|
||||
if cm.dtype.kind != 'f':
|
||||
text_d = format(cm[i, j], 'd')
|
||||
if len(text_d) < len(text_cm):
|
||||
text_cm = text_d
|
||||
else:
|
||||
text_cm = format(cm[i, j], values_format)
|
||||
|
||||
self.text_[i, j] = ax.text(
|
||||
j, i, text_cm,
|
||||
ha="center", va="center",
|
||||
color=color)
|
||||
|
||||
if self.display_labels is None:
|
||||
display_labels = np.arange(n_classes)
|
||||
else:
|
||||
display_labels = self.display_labels
|
||||
|
||||
fig.colorbar(self.im_, ax=ax)
|
||||
ax.set(xticks=np.arange(n_classes),
|
||||
yticks=np.arange(n_classes),
|
||||
xticklabels=display_labels,
|
||||
yticklabels=display_labels,
|
||||
ylabel="True label",
|
||||
xlabel="Predicted label")
|
||||
|
||||
ax.set_ylim((n_classes - 0.5, -0.5))
|
||||
plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)
|
||||
|
||||
self.figure_ = fig
|
||||
self.ax_ = ax
|
||||
return self
|
||||
|
||||
|
||||
@_deprecate_positional_args
|
||||
def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
|
||||
sample_weight=None, normalize=None,
|
||||
display_labels=None, include_values=True,
|
||||
xticks_rotation='horizontal',
|
||||
values_format=None,
|
||||
cmap='viridis', ax=None):
|
||||
"""Plot Confusion Matrix.
|
||||
|
||||
Read more in the :ref:`User Guide <confusion_matrix>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a classifier.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Target values.
|
||||
|
||||
labels : array-like of shape (n_classes,), default=None
|
||||
List of labels to index the matrix. This may be used to reorder or
|
||||
select a subset of labels. If `None` is given, those that appear at
|
||||
least once in `y_true` or `y_pred` are used in sorted order.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
normalize : {'true', 'pred', 'all'}, default=None
|
||||
Normalizes confusion matrix over the true (rows), predicted (columns)
|
||||
conditions or all the population. If None, confusion matrix will not be
|
||||
normalized.
|
||||
|
||||
display_labels : array-like of shape (n_classes,), default=None
|
||||
Target names used for plotting. By default, `labels` will be used if
|
||||
it is defined, otherwise the unique labels of `y_true` and `y_pred`
|
||||
will be used.
|
||||
|
||||
include_values : bool, default=True
|
||||
Includes values in confusion matrix.
|
||||
|
||||
xticks_rotation : {'vertical', 'horizontal'} or float, \
|
||||
default='horizontal'
|
||||
Rotation of xtick labels.
|
||||
|
||||
values_format : str, default=None
|
||||
Format specification for values in confusion matrix. If `None`,
|
||||
the format specification is 'd' or '.2g' whichever is shorter.
|
||||
|
||||
cmap : str or matplotlib Colormap, default='viridis'
|
||||
Colormap recognized by matplotlib.
|
||||
|
||||
ax : matplotlib Axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt # doctest: +SKIP
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import plot_confusion_matrix
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = SVC(random_state=0)
|
||||
>>> clf.fit(X_train, y_train)
|
||||
SVC(random_state=0)
|
||||
>>> plot_confusion_matrix(clf, X_test, y_test) # doctest: +SKIP
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
check_matplotlib_support("plot_confusion_matrix")
|
||||
|
||||
if not is_classifier(estimator):
|
||||
raise ValueError("plot_confusion_matrix only supports classifiers")
|
||||
|
||||
y_pred = estimator.predict(X)
|
||||
cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
|
||||
labels=labels, normalize=normalize)
|
||||
|
||||
if display_labels is None:
|
||||
if labels is None:
|
||||
display_labels = estimator.classes_
|
||||
else:
|
||||
display_labels = labels
|
||||
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
|
||||
display_labels=display_labels)
|
||||
return disp.plot(include_values=include_values,
|
||||
cmap=cmap, ax=ax, xticks_rotation=xticks_rotation,
|
||||
values_format=values_format)
|
|
@ -0,0 +1,181 @@
|
|||
from .base import _check_classifer_response_method
|
||||
|
||||
from .. import average_precision_score
|
||||
from .. import precision_recall_curve
|
||||
|
||||
from ...utils import check_matplotlib_support
|
||||
from ...utils.validation import _deprecate_positional_args
|
||||
from ...base import is_classifier
|
||||
|
||||
|
||||
class PrecisionRecallDisplay:
|
||||
"""Precision Recall visualization.
|
||||
|
||||
It is recommend to use :func:`~sklearn.metrics.plot_precision_recall_curve`
|
||||
to create a visualizer. All parameters are stored as attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
precision : ndarray
|
||||
Precision values.
|
||||
|
||||
recall : ndarray
|
||||
Recall values.
|
||||
|
||||
average_precision : float, default=None
|
||||
Average precision. If None, the average precision is not shown.
|
||||
|
||||
estimator_name : str, default=None
|
||||
Name of estimator. If None, then the estimator name is not shown.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
line_ : matplotlib Artist
|
||||
Precision recall curve.
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with precision recall curve.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the curve.
|
||||
"""
|
||||
def __init__(self, precision, recall, *,
|
||||
average_precision=None, estimator_name=None):
|
||||
self.precision = precision
|
||||
self.recall = recall
|
||||
self.average_precision = average_precision
|
||||
self.estimator_name = estimator_name
|
||||
|
||||
@_deprecate_positional_args
|
||||
def plot(self, ax=None, *, name=None, **kwargs):
|
||||
"""Plot visualization.
|
||||
|
||||
Extra keyword arguments will be passed to matplotlib's `plot`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : Matplotlib Axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
name : str, default=None
|
||||
Name of precision recall curve for labeling. If `None`, use the
|
||||
name of the estimator.
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
|
||||
Object that stores computed values.
|
||||
"""
|
||||
check_matplotlib_support("PrecisionRecallDisplay.plot")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
name = self.estimator_name if name is None else name
|
||||
|
||||
line_kwargs = {"drawstyle": "steps-post"}
|
||||
if self.average_precision is not None and name is not None:
|
||||
line_kwargs["label"] = (f"{name} (AP = "
|
||||
f"{self.average_precision:0.2f})")
|
||||
elif self.average_precision is not None:
|
||||
line_kwargs["label"] = (f"AP = "
|
||||
f"{self.average_precision:0.2f}")
|
||||
elif name is not None:
|
||||
line_kwargs["label"] = name
|
||||
line_kwargs.update(**kwargs)
|
||||
|
||||
self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
|
||||
ax.set(xlabel="Recall", ylabel="Precision")
|
||||
|
||||
if "label" in line_kwargs:
|
||||
ax.legend(loc='lower left')
|
||||
|
||||
self.ax_ = ax
|
||||
self.figure_ = ax.figure
|
||||
return self
|
||||
|
||||
|
||||
@_deprecate_positional_args
|
||||
def plot_precision_recall_curve(estimator, X, y, *,
|
||||
sample_weight=None, response_method="auto",
|
||||
name=None, ax=None, **kwargs):
|
||||
"""Plot Precision Recall Curve for binary classifiers.
|
||||
|
||||
Extra keyword arguments will be passed to matplotlib's `plot`.
|
||||
|
||||
Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a classifier.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Binary target values.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
response_method : {'predict_proba', 'decision_function', 'auto'}, \
|
||||
default='auto'
|
||||
Specifies whether to use :term:`predict_proba` or
|
||||
:term:`decision_function` as the target response. If set to 'auto',
|
||||
:term:`predict_proba` is tried first and if it does not exist
|
||||
:term:`decision_function` is tried next.
|
||||
|
||||
name : str, default=None
|
||||
Name for labeling curve. If `None`, the name of the
|
||||
estimator is used.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is created.
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
|
||||
Object that stores computed values.
|
||||
"""
|
||||
check_matplotlib_support("plot_precision_recall_curve")
|
||||
|
||||
classification_error = ("{} should be a binary classifier".format(
|
||||
estimator.__class__.__name__))
|
||||
if not is_classifier(estimator):
|
||||
raise ValueError(classification_error)
|
||||
|
||||
prediction_method = _check_classifer_response_method(estimator,
|
||||
response_method)
|
||||
y_pred = prediction_method(X)
|
||||
|
||||
if y_pred.ndim != 1:
|
||||
if y_pred.shape[1] != 2:
|
||||
raise ValueError(classification_error)
|
||||
else:
|
||||
y_pred = y_pred[:, 1]
|
||||
|
||||
pos_label = estimator.classes_[1]
|
||||
precision, recall, _ = precision_recall_curve(y, y_pred,
|
||||
pos_label=pos_label,
|
||||
sample_weight=sample_weight)
|
||||
average_precision = average_precision_score(y, y_pred,
|
||||
pos_label=pos_label,
|
||||
sample_weight=sample_weight)
|
||||
name = name if name is not None else estimator.__class__.__name__
|
||||
viz = PrecisionRecallDisplay(
|
||||
precision=precision, recall=recall,
|
||||
average_precision=average_precision, estimator_name=name
|
||||
)
|
||||
return viz.plot(ax=ax, name=name, **kwargs)
|
203
venv/Lib/site-packages/sklearn/metrics/_plot/roc_curve.py
Normal file
203
venv/Lib/site-packages/sklearn/metrics/_plot/roc_curve.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
from .. import auc
|
||||
from .. import roc_curve
|
||||
|
||||
from .base import _check_classifer_response_method
|
||||
from ...utils import check_matplotlib_support
|
||||
from ...base import is_classifier
|
||||
from ...utils.validation import _deprecate_positional_args
|
||||
|
||||
|
||||
class RocCurveDisplay:
|
||||
"""ROC Curve visualization.
|
||||
|
||||
It is recommend to use :func:`~sklearn.metrics.plot_roc_curve` to create a
|
||||
visualizer. All parameters are stored as attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fpr : ndarray
|
||||
False positive rate.
|
||||
|
||||
tpr : ndarray
|
||||
True positive rate.
|
||||
|
||||
roc_auc : float, default=None
|
||||
Area under ROC curve. If None, the roc_auc score is not shown.
|
||||
|
||||
estimator_name : str, default=None
|
||||
Name of estimator. If None, the estimator name is not shown.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
line_ : matplotlib Artist
|
||||
ROC Curve.
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with ROC Curve.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the curve.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt # doctest: +SKIP
|
||||
>>> import numpy as np
|
||||
>>> from sklearn import metrics
|
||||
>>> y = np.array([0, 0, 1, 1])
|
||||
>>> pred = np.array([0.1, 0.4, 0.35, 0.8])
|
||||
>>> fpr, tpr, thresholds = metrics.roc_curve(y, pred)
|
||||
>>> roc_auc = metrics.auc(fpr, tpr)
|
||||
>>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,\
|
||||
estimator_name='example estimator')
|
||||
>>> display.plot() # doctest: +SKIP
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None):
|
||||
self.fpr = fpr
|
||||
self.tpr = tpr
|
||||
self.roc_auc = roc_auc
|
||||
self.estimator_name = estimator_name
|
||||
|
||||
@_deprecate_positional_args
|
||||
def plot(self, ax=None, *, name=None, **kwargs):
|
||||
"""Plot visualization
|
||||
|
||||
Extra keyword arguments will be passed to matplotlib's ``plot``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
name : str, default=None
|
||||
Name of ROC Curve for labeling. If `None`, use the name of the
|
||||
estimator.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.plot.RocCurveDisplay`
|
||||
Object that stores computed values.
|
||||
"""
|
||||
check_matplotlib_support('RocCurveDisplay.plot')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
name = self.estimator_name if name is None else name
|
||||
|
||||
line_kwargs = {}
|
||||
if self.roc_auc is not None and name is not None:
|
||||
line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})"
|
||||
elif self.roc_auc is not None:
|
||||
line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}"
|
||||
elif name is not None:
|
||||
line_kwargs["label"] = name
|
||||
|
||||
line_kwargs.update(**kwargs)
|
||||
|
||||
self.line_ = ax.plot(self.fpr, self.tpr, **line_kwargs)[0]
|
||||
ax.set_xlabel("False Positive Rate")
|
||||
ax.set_ylabel("True Positive Rate")
|
||||
|
||||
if "label" in line_kwargs:
|
||||
ax.legend(loc='lower right')
|
||||
|
||||
self.ax_ = ax
|
||||
self.figure_ = ax.figure
|
||||
return self
|
||||
|
||||
|
||||
@_deprecate_positional_args
|
||||
def plot_roc_curve(estimator, X, y, *, sample_weight=None,
|
||||
drop_intermediate=True, response_method="auto",
|
||||
name=None, ax=None, **kwargs):
|
||||
"""Plot Receiver operating characteristic (ROC) curve.
|
||||
|
||||
Extra keyword arguments will be passed to matplotlib's `plot`.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a classifier.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Target values.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
drop_intermediate : boolean, default=True
|
||||
Whether to drop some suboptimal thresholds which would not appear
|
||||
on a plotted ROC curve. This is useful in order to create lighter
|
||||
ROC curves.
|
||||
|
||||
response_method : {'predict_proba', 'decision_function', 'auto'} \
|
||||
default='auto'
|
||||
Specifies whether to use :term:`predict_proba` or
|
||||
:term:`decision_function` as the target response. If set to 'auto',
|
||||
:term:`predict_proba` is tried first and if it does not exist
|
||||
:term:`decision_function` is tried next.
|
||||
|
||||
name : str, default=None
|
||||
Name of ROC Curve for labeling. If `None`, use the name of the
|
||||
estimator.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.RocCurveDisplay`
|
||||
Object that stores computed values.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt # doctest: +SKIP
|
||||
>>> from sklearn import datasets, metrics, model_selection, svm
|
||||
>>> X, y = datasets.make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = model_selection.train_test_split(\
|
||||
X, y, random_state=0)
|
||||
>>> clf = svm.SVC(random_state=0)
|
||||
>>> clf.fit(X_train, y_train)
|
||||
SVC(random_state=0)
|
||||
>>> metrics.plot_roc_curve(clf, X_test, y_test) # doctest: +SKIP
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
check_matplotlib_support('plot_roc_curve')
|
||||
|
||||
classification_error = (
|
||||
"{} should be a binary classifier".format(estimator.__class__.__name__)
|
||||
)
|
||||
if not is_classifier(estimator):
|
||||
raise ValueError(classification_error)
|
||||
|
||||
prediction_method = _check_classifer_response_method(estimator,
|
||||
response_method)
|
||||
y_pred = prediction_method(X)
|
||||
|
||||
if y_pred.ndim != 1:
|
||||
if y_pred.shape[1] != 2:
|
||||
raise ValueError(classification_error)
|
||||
else:
|
||||
y_pred = y_pred[:, 1]
|
||||
|
||||
pos_label = estimator.classes_[1]
|
||||
fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label,
|
||||
sample_weight=sample_weight,
|
||||
drop_intermediate=drop_intermediate)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
name = estimator.__class__.__name__ if name is None else name
|
||||
viz = RocCurveDisplay(
|
||||
fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name
|
||||
)
|
||||
return viz.plot(ax=ax, name=name, **kwargs)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,299 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVC, SVR
|
||||
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from sklearn.metrics import plot_confusion_matrix
|
||||
from sklearn.metrics import ConfusionMatrixDisplay
|
||||
|
||||
|
||||
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
|
||||
"matplotlib.*")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def n_classes():
|
||||
return 5
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def data(n_classes):
|
||||
X, y = make_classification(n_samples=100, n_informative=5,
|
||||
n_classes=n_classes, random_state=0)
|
||||
return X, y
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def fitted_clf(data):
|
||||
return SVC(kernel='linear', C=0.01).fit(*data)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def y_pred(data, fitted_clf):
|
||||
X, _ = data
|
||||
return fitted_clf.predict(X)
|
||||
|
||||
|
||||
def test_error_on_regressor(pyplot, data):
|
||||
X, y = data
|
||||
est = SVR().fit(X, y)
|
||||
|
||||
msg = "plot_confusion_matrix only supports classifiers"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
plot_confusion_matrix(est, X, y)
|
||||
|
||||
|
||||
def test_error_on_invalid_option(pyplot, fitted_clf, data):
|
||||
X, y = data
|
||||
msg = (r"normalize must be one of \{'true', 'pred', 'all', "
|
||||
r"None\}")
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
plot_confusion_matrix(fitted_clf, X, y, normalize='invalid')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("with_labels", [True, False])
|
||||
@pytest.mark.parametrize("with_display_labels", [True, False])
|
||||
def test_plot_confusion_matrix_custom_labels(pyplot, data, y_pred, fitted_clf,
|
||||
n_classes, with_labels,
|
||||
with_display_labels):
|
||||
X, y = data
|
||||
ax = pyplot.gca()
|
||||
labels = [2, 1, 0, 3, 4] if with_labels else None
|
||||
display_labels = ['b', 'd', 'a', 'e', 'f'] if with_display_labels else None
|
||||
|
||||
cm = confusion_matrix(y, y_pred, labels=labels)
|
||||
disp = plot_confusion_matrix(fitted_clf, X, y,
|
||||
ax=ax, display_labels=display_labels,
|
||||
labels=labels)
|
||||
|
||||
assert_allclose(disp.confusion_matrix, cm)
|
||||
|
||||
if with_display_labels:
|
||||
expected_display_labels = display_labels
|
||||
elif with_labels:
|
||||
expected_display_labels = labels
|
||||
else:
|
||||
expected_display_labels = list(range(n_classes))
|
||||
|
||||
expected_display_labels_str = [str(name)
|
||||
for name in expected_display_labels]
|
||||
|
||||
x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
|
||||
y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]
|
||||
|
||||
assert_array_equal(disp.display_labels, expected_display_labels)
|
||||
assert_array_equal(x_ticks, expected_display_labels_str)
|
||||
assert_array_equal(y_ticks, expected_display_labels_str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
|
||||
@pytest.mark.parametrize("include_values", [True, False])
|
||||
def test_plot_confusion_matrix(pyplot, data, y_pred, n_classes, fitted_clf,
|
||||
normalize, include_values):
|
||||
X, y = data
|
||||
ax = pyplot.gca()
|
||||
cmap = 'plasma'
|
||||
cm = confusion_matrix(y, y_pred)
|
||||
disp = plot_confusion_matrix(fitted_clf, X, y,
|
||||
normalize=normalize,
|
||||
cmap=cmap, ax=ax,
|
||||
include_values=include_values)
|
||||
|
||||
assert disp.ax_ == ax
|
||||
|
||||
if normalize == 'true':
|
||||
cm = cm / cm.sum(axis=1, keepdims=True)
|
||||
elif normalize == 'pred':
|
||||
cm = cm / cm.sum(axis=0, keepdims=True)
|
||||
elif normalize == 'all':
|
||||
cm = cm / cm.sum()
|
||||
|
||||
assert_allclose(disp.confusion_matrix, cm)
|
||||
import matplotlib as mpl
|
||||
assert isinstance(disp.im_, mpl.image.AxesImage)
|
||||
assert disp.im_.get_cmap().name == cmap
|
||||
assert isinstance(disp.ax_, pyplot.Axes)
|
||||
assert isinstance(disp.figure_, pyplot.Figure)
|
||||
|
||||
assert disp.ax_.get_ylabel() == "True label"
|
||||
assert disp.ax_.get_xlabel() == "Predicted label"
|
||||
|
||||
x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
|
||||
y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]
|
||||
|
||||
expected_display_labels = list(range(n_classes))
|
||||
|
||||
expected_display_labels_str = [str(name)
|
||||
for name in expected_display_labels]
|
||||
|
||||
assert_array_equal(disp.display_labels, expected_display_labels)
|
||||
assert_array_equal(x_ticks, expected_display_labels_str)
|
||||
assert_array_equal(y_ticks, expected_display_labels_str)
|
||||
|
||||
image_data = disp.im_.get_array().data
|
||||
assert_allclose(image_data, cm)
|
||||
|
||||
if include_values:
|
||||
assert disp.text_.shape == (n_classes, n_classes)
|
||||
fmt = '.2g'
|
||||
expected_text = np.array([format(v, fmt) for v in cm.ravel(order="C")])
|
||||
text_text = np.array([
|
||||
t.get_text() for t in disp.text_.ravel(order="C")])
|
||||
assert_array_equal(expected_text, text_text)
|
||||
else:
|
||||
assert disp.text_ is None
|
||||
|
||||
|
||||
def test_confusion_matrix_display(pyplot, data, fitted_clf, y_pred, n_classes):
|
||||
X, y = data
|
||||
|
||||
cm = confusion_matrix(y, y_pred)
|
||||
disp = plot_confusion_matrix(fitted_clf, X, y, normalize=None,
|
||||
include_values=True, cmap='viridis',
|
||||
xticks_rotation=45.0)
|
||||
|
||||
assert_allclose(disp.confusion_matrix, cm)
|
||||
assert disp.text_.shape == (n_classes, n_classes)
|
||||
|
||||
rotations = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
|
||||
assert_allclose(rotations, 45.0)
|
||||
|
||||
image_data = disp.im_.get_array().data
|
||||
assert_allclose(image_data, cm)
|
||||
|
||||
disp.plot(cmap='plasma')
|
||||
assert disp.im_.get_cmap().name == 'plasma'
|
||||
|
||||
disp.plot(include_values=False)
|
||||
assert disp.text_ is None
|
||||
|
||||
disp.plot(xticks_rotation=90.0)
|
||||
rotations = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
|
||||
assert_allclose(rotations, 90.0)
|
||||
|
||||
disp.plot(values_format='e')
|
||||
expected_text = np.array([format(v, 'e') for v in cm.ravel(order="C")])
|
||||
text_text = np.array([
|
||||
t.get_text() for t in disp.text_.ravel(order="C")])
|
||||
assert_array_equal(expected_text, text_text)
|
||||
|
||||
|
||||
def test_confusion_matrix_contrast(pyplot):
|
||||
# make sure text color is appropriate depending on background
|
||||
|
||||
cm = np.eye(2) / 2
|
||||
disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
|
||||
|
||||
disp.plot(cmap=pyplot.cm.gray)
|
||||
# diagonal text is black
|
||||
assert_allclose(disp.text_[0, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
|
||||
assert_allclose(disp.text_[1, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
|
||||
|
||||
# off-diagonal text is white
|
||||
assert_allclose(disp.text_[0, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
|
||||
assert_allclose(disp.text_[1, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
disp.plot(cmap=pyplot.cm.gray_r)
|
||||
# diagonal text is white
|
||||
assert_allclose(disp.text_[0, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
|
||||
assert_allclose(disp.text_[1, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
|
||||
|
||||
# off-diagonal text is black
|
||||
assert_allclose(disp.text_[0, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
|
||||
assert_allclose(disp.text_[1, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
# Regression test for #15920
|
||||
cm = np.array([[19, 34], [32, 58]])
|
||||
disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
|
||||
|
||||
disp.plot(cmap=pyplot.cm.Blues)
|
||||
min_color = pyplot.cm.Blues(0)
|
||||
max_color = pyplot.cm.Blues(255)
|
||||
assert_allclose(disp.text_[0, 0].get_color(), max_color)
|
||||
assert_allclose(disp.text_[0, 1].get_color(), max_color)
|
||||
assert_allclose(disp.text_[1, 0].get_color(), max_color)
|
||||
assert_allclose(disp.text_[1, 1].get_color(), min_color)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clf", [LogisticRegression(),
|
||||
make_pipeline(StandardScaler(), LogisticRegression()),
|
||||
make_pipeline(make_column_transformer((StandardScaler(), [0, 1])),
|
||||
LogisticRegression())])
|
||||
def test_confusion_matrix_pipeline(pyplot, clf, data, n_classes):
|
||||
X, y = data
|
||||
with pytest.raises(NotFittedError):
|
||||
plot_confusion_matrix(clf, X, y)
|
||||
clf.fit(X, y)
|
||||
y_pred = clf.predict(X)
|
||||
|
||||
disp = plot_confusion_matrix(clf, X, y)
|
||||
cm = confusion_matrix(y, y_pred)
|
||||
|
||||
assert_allclose(disp.confusion_matrix, cm)
|
||||
assert disp.text_.shape == (n_classes, n_classes)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("values_format", ['e', 'n'])
|
||||
def test_confusion_matrix_text_format(pyplot, data, y_pred, n_classes,
|
||||
fitted_clf, values_format):
|
||||
# Make sure plot text is formatted with 'values_format'.
|
||||
X, y = data
|
||||
cm = confusion_matrix(y, y_pred)
|
||||
disp = plot_confusion_matrix(fitted_clf, X, y,
|
||||
include_values=True,
|
||||
values_format=values_format)
|
||||
|
||||
assert disp.text_.shape == (n_classes, n_classes)
|
||||
|
||||
expected_text = np.array([format(v, values_format)
|
||||
for v in cm.ravel()])
|
||||
text_text = np.array([
|
||||
t.get_text() for t in disp.text_.ravel()])
|
||||
assert_array_equal(expected_text, text_text)
|
||||
|
||||
|
||||
def test_confusion_matrix_standard_format(pyplot):
|
||||
cm = np.array([[10000000, 0], [123456, 12345678]])
|
||||
plotted_text = ConfusionMatrixDisplay(
|
||||
cm, display_labels=[False, True]).plot().text_
|
||||
# Values should be shown as whole numbers 'd',
|
||||
# except the first number which should be shown as 1e+07 (longer length)
|
||||
# and the last number will be shown as 1.2e+07 (longer length)
|
||||
test = [t.get_text() for t in plotted_text.ravel()]
|
||||
assert test == ['1e+07', '0', '123456', '1.2e+07']
|
||||
|
||||
cm = np.array([[0.1, 10], [100, 0.525]])
|
||||
plotted_text = ConfusionMatrixDisplay(
|
||||
cm, display_labels=[False, True]).plot().text_
|
||||
# Values should now formatted as '.2g', since there's a float in
|
||||
# Values are have two dec places max, (e.g 100 becomes 1e+02)
|
||||
test = [t.get_text() for t in plotted_text.ravel()]
|
||||
assert test == ['0.1', '10', '1e+02', '0.53']
|
||||
|
||||
|
||||
@pytest.mark.parametrize("display_labels, expected_labels", [
|
||||
(None, ["0", "1"]),
|
||||
(["cat", "dog"], ["cat", "dog"]),
|
||||
])
|
||||
def test_default_labels(pyplot, display_labels, expected_labels):
|
||||
cm = np.array([[10, 0], [12, 120]])
|
||||
disp = ConfusionMatrixDisplay(cm, display_labels=display_labels).plot()
|
||||
|
||||
x_ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
|
||||
y_ticks = [tick.get_text() for tick in disp.ax_.get_yticklabels()]
|
||||
|
||||
assert_array_equal(x_ticks, expected_labels)
|
||||
assert_array_equal(y_ticks, expected_labels)
|
|
@ -0,0 +1,192 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.metrics import plot_precision_recall_curve
|
||||
from sklearn.metrics import PrecisionRecallDisplay
|
||||
from sklearn.metrics import average_precision_score
|
||||
from sklearn.metrics import precision_recall_curve
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.compose import make_column_transformer
|
||||
|
||||
|
||||
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
|
||||
"matplotlib.*")
|
||||
|
||||
|
||||
def test_errors(pyplot):
|
||||
X, y_multiclass = make_classification(n_classes=3, n_samples=50,
|
||||
n_informative=3,
|
||||
random_state=0)
|
||||
y_binary = y_multiclass == 0
|
||||
|
||||
# Unfitted classifer
|
||||
binary_clf = DecisionTreeClassifier()
|
||||
with pytest.raises(NotFittedError):
|
||||
plot_precision_recall_curve(binary_clf, X, y_binary)
|
||||
binary_clf.fit(X, y_binary)
|
||||
|
||||
multi_clf = DecisionTreeClassifier().fit(X, y_multiclass)
|
||||
|
||||
# Fitted multiclass classifier with binary data
|
||||
msg = "DecisionTreeClassifier should be a binary classifier"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
plot_precision_recall_curve(multi_clf, X, y_binary)
|
||||
|
||||
reg = DecisionTreeRegressor().fit(X, y_multiclass)
|
||||
msg = "DecisionTreeRegressor should be a binary classifier"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
plot_precision_recall_curve(reg, X, y_binary)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response_method, msg",
|
||||
[("predict_proba", "response method predict_proba is not defined in "
|
||||
"MyClassifier"),
|
||||
("decision_function", "response method decision_function is not defined "
|
||||
"in MyClassifier"),
|
||||
("auto", "response method decision_function or predict_proba is not "
|
||||
"defined in MyClassifier"),
|
||||
("bad_method", "response_method must be 'predict_proba', "
|
||||
"'decision_function' or 'auto'")])
|
||||
def test_error_bad_response(pyplot, response_method, msg):
|
||||
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
|
||||
|
||||
class MyClassifier(BaseEstimator, ClassifierMixin):
|
||||
def fit(self, X, y):
|
||||
self.fitted_ = True
|
||||
self.classes_ = [0, 1]
|
||||
return self
|
||||
|
||||
clf = MyClassifier().fit(X, y)
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
plot_precision_recall_curve(clf, X, y, response_method=response_method)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_method",
|
||||
["predict_proba", "decision_function"])
|
||||
@pytest.mark.parametrize("with_sample_weight", [True, False])
|
||||
def test_plot_precision_recall(pyplot, response_method, with_sample_weight):
|
||||
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
|
||||
|
||||
lr = LogisticRegression().fit(X, y)
|
||||
|
||||
if with_sample_weight:
|
||||
rng = np.random.RandomState(42)
|
||||
sample_weight = rng.randint(0, 4, size=X.shape[0])
|
||||
else:
|
||||
sample_weight = None
|
||||
|
||||
disp = plot_precision_recall_curve(lr, X, y, alpha=0.8,
|
||||
response_method=response_method,
|
||||
sample_weight=sample_weight)
|
||||
|
||||
y_score = getattr(lr, response_method)(X)
|
||||
if response_method == 'predict_proba':
|
||||
y_score = y_score[:, 1]
|
||||
|
||||
prec, recall, _ = precision_recall_curve(y, y_score,
|
||||
sample_weight=sample_weight)
|
||||
avg_prec = average_precision_score(y, y_score, sample_weight=sample_weight)
|
||||
|
||||
assert_allclose(disp.precision, prec)
|
||||
assert_allclose(disp.recall, recall)
|
||||
assert disp.average_precision == pytest.approx(avg_prec)
|
||||
|
||||
assert disp.estimator_name == "LogisticRegression"
|
||||
|
||||
# cannot fail thanks to pyplot fixture
|
||||
import matplotlib as mpl # noqa
|
||||
assert isinstance(disp.line_, mpl.lines.Line2D)
|
||||
assert disp.line_.get_alpha() == 0.8
|
||||
assert isinstance(disp.ax_, mpl.axes.Axes)
|
||||
assert isinstance(disp.figure_, mpl.figure.Figure)
|
||||
|
||||
expected_label = "LogisticRegression (AP = {:0.2f})".format(avg_prec)
|
||||
assert disp.line_.get_label() == expected_label
|
||||
assert disp.ax_.get_xlabel() == "Recall"
|
||||
assert disp.ax_.get_ylabel() == "Precision"
|
||||
|
||||
# draw again with another label
|
||||
disp.plot(name="MySpecialEstimator")
|
||||
expected_label = "MySpecialEstimator (AP = {:0.2f})".format(avg_prec)
|
||||
assert disp.line_.get_label() == expected_label
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clf", [make_pipeline(StandardScaler(), LogisticRegression()),
|
||||
make_pipeline(make_column_transformer((StandardScaler(), [0, 1])),
|
||||
LogisticRegression())])
|
||||
def test_precision_recall_curve_pipeline(pyplot, clf):
|
||||
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
|
||||
with pytest.raises(NotFittedError):
|
||||
plot_precision_recall_curve(clf, X, y)
|
||||
clf.fit(X, y)
|
||||
disp = plot_precision_recall_curve(clf, X, y)
|
||||
assert disp.estimator_name == clf.__class__.__name__
|
||||
|
||||
|
||||
def test_precision_recall_curve_string_labels(pyplot):
|
||||
# regression test #15738
|
||||
cancer = load_breast_cancer()
|
||||
X = cancer.data
|
||||
y = cancer.target_names[cancer.target]
|
||||
|
||||
lr = make_pipeline(StandardScaler(), LogisticRegression())
|
||||
lr.fit(X, y)
|
||||
for klass in cancer.target_names:
|
||||
assert klass in lr.classes_
|
||||
disp = plot_precision_recall_curve(lr, X, y)
|
||||
|
||||
y_pred = lr.predict_proba(X)[:, 1]
|
||||
avg_prec = average_precision_score(y, y_pred,
|
||||
pos_label=lr.classes_[1])
|
||||
|
||||
assert disp.average_precision == pytest.approx(avg_prec)
|
||||
assert disp.estimator_name == lr.__class__.__name__
|
||||
|
||||
|
||||
def test_plot_precision_recall_curve_estimator_name_multiple_calls(pyplot):
|
||||
# non-regression test checking that the `name` used when calling
|
||||
# `plot_roc_curve` is used as well when calling `disp.plot()`
|
||||
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
|
||||
clf_name = "my hand-crafted name"
|
||||
clf = LogisticRegression().fit(X, y)
|
||||
disp = plot_precision_recall_curve(clf, X, y, name=clf_name)
|
||||
assert disp.estimator_name == clf_name
|
||||
pyplot.close("all")
|
||||
disp.plot()
|
||||
assert clf_name in disp.line_.get_label()
|
||||
pyplot.close("all")
|
||||
clf_name = "another_name"
|
||||
disp.plot(name=clf_name)
|
||||
assert clf_name in disp.line_.get_label()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"average_precision, estimator_name, expected_label",
|
||||
[
|
||||
(0.9, None, "AP = 0.90"),
|
||||
(None, "my_est", "my_est"),
|
||||
(0.8, "my_est2", "my_est2 (AP = 0.80)"),
|
||||
]
|
||||
)
|
||||
def test_default_labels(pyplot, average_precision, estimator_name,
|
||||
expected_label):
|
||||
prec = np.array([1, 0.5, 0])
|
||||
recall = np.array([0, 0.5, 1])
|
||||
disp = PrecisionRecallDisplay(prec, recall,
|
||||
average_precision=average_precision,
|
||||
estimator_name=estimator_name)
|
||||
disp.plot()
|
||||
assert disp.line_.get_label() == expected_label
|
|
@ -0,0 +1,170 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_allclose
|
||||
import numpy as np
|
||||
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.metrics import plot_roc_curve
|
||||
from sklearn.metrics import RocCurveDisplay
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
from sklearn.base import ClassifierMixin
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.compose import make_column_transformer
|
||||
|
||||
|
||||
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
|
||||
"matplotlib.*")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def data():
|
||||
return load_iris(return_X_y=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def data_binary(data):
|
||||
X, y = data
|
||||
return X[y < 2], y[y < 2]
|
||||
|
||||
|
||||
def test_plot_roc_curve_error_non_binary(pyplot, data):
|
||||
X, y = data
|
||||
clf = DecisionTreeClassifier()
|
||||
clf.fit(X, y)
|
||||
|
||||
msg = "DecisionTreeClassifier should be a binary classifier"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
plot_roc_curve(clf, X, y)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response_method, msg",
|
||||
[("predict_proba", "response method predict_proba is not defined in "
|
||||
"MyClassifier"),
|
||||
("decision_function", "response method decision_function is not defined "
|
||||
"in MyClassifier"),
|
||||
("auto", "response method decision_function or predict_proba is not "
|
||||
"defined in MyClassifier"),
|
||||
("bad_method", "response_method must be 'predict_proba', "
|
||||
"'decision_function' or 'auto'")])
|
||||
def test_plot_roc_curve_error_no_response(pyplot, data_binary, response_method,
|
||||
msg):
|
||||
X, y = data_binary
|
||||
|
||||
class MyClassifier(ClassifierMixin):
|
||||
def fit(self, X, y):
|
||||
self.classes_ = [0, 1]
|
||||
return self
|
||||
|
||||
clf = MyClassifier().fit(X, y)
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
plot_roc_curve(clf, X, y, response_method=response_method)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_method",
|
||||
["predict_proba", "decision_function"])
|
||||
@pytest.mark.parametrize("with_sample_weight", [True, False])
|
||||
@pytest.mark.parametrize("drop_intermediate", [True, False])
|
||||
@pytest.mark.parametrize("with_strings", [True, False])
|
||||
def test_plot_roc_curve(pyplot, response_method, data_binary,
|
||||
with_sample_weight, drop_intermediate,
|
||||
with_strings):
|
||||
X, y = data_binary
|
||||
|
||||
pos_label = None
|
||||
if with_strings:
|
||||
y = np.array(["c", "b"])[y]
|
||||
pos_label = "c"
|
||||
|
||||
if with_sample_weight:
|
||||
rng = np.random.RandomState(42)
|
||||
sample_weight = rng.randint(1, 4, size=(X.shape[0]))
|
||||
else:
|
||||
sample_weight = None
|
||||
|
||||
lr = LogisticRegression()
|
||||
lr.fit(X, y)
|
||||
|
||||
viz = plot_roc_curve(lr, X, y, alpha=0.8, sample_weight=sample_weight,
|
||||
drop_intermediate=drop_intermediate)
|
||||
|
||||
y_pred = getattr(lr, response_method)(X)
|
||||
if y_pred.ndim == 2:
|
||||
y_pred = y_pred[:, 1]
|
||||
|
||||
fpr, tpr, _ = roc_curve(y, y_pred, sample_weight=sample_weight,
|
||||
drop_intermediate=drop_intermediate,
|
||||
pos_label=pos_label)
|
||||
|
||||
assert_allclose(viz.roc_auc, auc(fpr, tpr))
|
||||
assert_allclose(viz.fpr, fpr)
|
||||
assert_allclose(viz.tpr, tpr)
|
||||
|
||||
assert viz.estimator_name == "LogisticRegression"
|
||||
|
||||
# cannot fail thanks to pyplot fixture
|
||||
import matplotlib as mpl # noqal
|
||||
assert isinstance(viz.line_, mpl.lines.Line2D)
|
||||
assert viz.line_.get_alpha() == 0.8
|
||||
assert isinstance(viz.ax_, mpl.axes.Axes)
|
||||
assert isinstance(viz.figure_, mpl.figure.Figure)
|
||||
|
||||
expected_label = "LogisticRegression (AUC = {:0.2f})".format(viz.roc_auc)
|
||||
assert viz.line_.get_label() == expected_label
|
||||
assert viz.ax_.get_ylabel() == "True Positive Rate"
|
||||
assert viz.ax_.get_xlabel() == "False Positive Rate"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"clf", [LogisticRegression(),
|
||||
make_pipeline(StandardScaler(), LogisticRegression()),
|
||||
make_pipeline(make_column_transformer((StandardScaler(), [0, 1])),
|
||||
LogisticRegression())])
|
||||
def test_roc_curve_not_fitted_errors(pyplot, data_binary, clf):
|
||||
X, y = data_binary
|
||||
with pytest.raises(NotFittedError):
|
||||
plot_roc_curve(clf, X, y)
|
||||
clf.fit(X, y)
|
||||
disp = plot_roc_curve(clf, X, y)
|
||||
assert clf.__class__.__name__ in disp.line_.get_label()
|
||||
assert disp.estimator_name == clf.__class__.__name__
|
||||
|
||||
|
||||
def test_plot_roc_curve_estimator_name_multiple_calls(pyplot, data_binary):
|
||||
# non-regression test checking that the `name` used when calling
|
||||
# `plot_roc_curve` is used as well when calling `disp.plot()`
|
||||
X, y = data_binary
|
||||
clf_name = "my hand-crafted name"
|
||||
clf = LogisticRegression().fit(X, y)
|
||||
disp = plot_roc_curve(clf, X, y, name=clf_name)
|
||||
assert disp.estimator_name == clf_name
|
||||
pyplot.close("all")
|
||||
disp.plot()
|
||||
assert clf_name in disp.line_.get_label()
|
||||
pyplot.close("all")
|
||||
clf_name = "another_name"
|
||||
disp.plot(name=clf_name)
|
||||
assert clf_name in disp.line_.get_label()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"roc_auc, estimator_name, expected_label",
|
||||
[
|
||||
(0.9, None, "AUC = 0.90"),
|
||||
(None, "my_est", "my_est"),
|
||||
(0.8, "my_est2", "my_est2 (AUC = 0.80)")
|
||||
]
|
||||
)
|
||||
def test_default_labels(pyplot, roc_auc, estimator_name,
|
||||
expected_label):
|
||||
fpr = np.array([0, 0.5, 1])
|
||||
tpr = np.array([0, 0.5, 1])
|
||||
disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
|
||||
estimator_name=estimator_name).plot()
|
||||
assert disp.line_.get_label() == expected_label
|
Loading…
Add table
Add a link
Reference in a new issue