# Source: Vehicle-Anti-Theft-Face-Rec.../venv/Lib/site-packages/sklearn/tests/test_metaestimators.py
#
# 147 lines
# 5.2 KiB
# Python

"""Common tests for metaestimators"""
import functools
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.datasets import make_classification
from sklearn.utils._testing import assert_raises
from sklearn.utils.validation import check_is_fitted
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.feature_selection import RFE, RFECV
from sklearn.ensemble import BaggingClassifier
from sklearn.exceptions import NotFittedError
class DelegatorData:
    """Describe one meta-estimator exercised by the delegation test.

    Parameters
    ----------
    name : str
        Label identifying the meta-estimator in assertion messages.
    construct : callable
        Called with a sub-estimator instance; returns the meta-estimator
        that wraps it.
    skip_methods : sequence of str, default=()
        Names of methods the delegation test must not check for this
        meta-estimator.
    fit_args : tuple, default=None
        Positional arguments unpacked into the meta-estimator's ``fit``.
        Its first element is passed to every delegated method and its
        second element is additionally passed to ``score``. When ``None``,
        a dataset is generated with ``make_classification()``.
    """

    def __init__(self, name, construct, skip_methods=(), fit_args=None):
        if fit_args is None:
            # Built lazily: a call placed in the signature would run once,
            # when the class body is executed, and the very same dataset
            # object would then be shared by every instance.
            fit_args = make_classification()
        self.name = name
        self.construct = construct
        self.fit_args = fit_args
        self.skip_methods = skip_methods
# Meta-estimators covered by the delegation test. Each entry pairs a label
# with a factory that wraps a sub-estimator, plus the method names that the
# test skips for that meta-estimator.
DELEGATING_METAESTIMATORS = [
    DelegatorData(
        name='Pipeline',
        construct=lambda est: Pipeline([('est', est)]),
    ),
    DelegatorData(
        name='GridSearchCV',
        construct=lambda est: GridSearchCV(
            est, param_grid={'param': [5]}, cv=2),
        skip_methods=['score'],
    ),
    DelegatorData(
        name='RandomizedSearchCV',
        construct=lambda est: RandomizedSearchCV(
            est, param_distributions={'param': [5]}, cv=2, n_iter=1),
        skip_methods=['score'],
    ),
    DelegatorData(
        name='RFE',
        construct=RFE,
        skip_methods=['transform', 'inverse_transform'],
    ),
    DelegatorData(
        name='RFECV',
        construct=RFECV,
        skip_methods=['transform', 'inverse_transform'],
    ),
    DelegatorData(
        name='BaggingClassifier',
        construct=BaggingClassifier,
        skip_methods=[
            'transform', 'inverse_transform', 'score',
            'predict_proba', 'predict_log_proba',
            'predict',
        ],
    ),
]
def test_metaestimator_delegation():
    """Check that each meta-estimator has a method iff its sub-estimator does.

    For every entry of ``DELEGATING_METAESTIMATORS`` and every non-skipped
    method, the test verifies that:

    * the method is exposed by both the sub-estimator and the wrapper;
    * calling it on the wrapper before ``fit`` raises ``NotFittedError``;
    * calling it after ``fit`` runs without error (smoke test);
    * hiding it on the sub-estimator also hides it on the wrapper.
    """

    def hides(method):
        # Expose ``method`` through a property whose getter raises
        # AttributeError when the instance's ``hidden_method`` names it, so
        # that hasattr() reports the method as absent on that instance.
        def getter(obj):
            if method.__name__ == obj.hidden_method:
                raise AttributeError('%r is hidden' % obj.hidden_method)
            return functools.partial(method, obj)
        return property(getter)

    class SubEstimator(BaseEstimator):
        # Minimal estimator whose public methods can be hidden one at a time.

        def __init__(self, param=1, hidden_method=None):
            self.param = param
            self.hidden_method = hidden_method

        def fit(self, X, y=None, *args, **kwargs):
            # Setting a trailing-underscore attribute is what marks the
            # instance as fitted for check_is_fitted().
            self.coef_ = np.arange(X.shape[1])
            return True

        @hides
        def inverse_transform(self, X, *args, **kwargs):
            check_is_fitted(self)
            return X

        @hides
        def transform(self, X, *args, **kwargs):
            check_is_fitted(self)
            return X

        @hides
        def predict(self, X, *args, **kwargs):
            check_is_fitted(self)
            return np.ones(X.shape[0])

        @hides
        def predict_proba(self, X, *args, **kwargs):
            check_is_fitted(self)
            return np.ones(X.shape[0])

        @hides
        def predict_log_proba(self, X, *args, **kwargs):
            check_is_fitted(self)
            return np.ones(X.shape[0])

        @hides
        def decision_function(self, X, *args, **kwargs):
            check_is_fitted(self)
            return np.ones(X.shape[0])

        @hides
        def score(self, X, y, *args, **kwargs):
            check_is_fitted(self)
            return 1.0

    def call_args(spec, name):
        # ``score`` takes (X, y); every other delegated method takes X only.
        if name == 'score':
            return spec.fit_args[0], spec.fit_args[1]
        return (spec.fit_args[0],)

    # Public, non-fit attributes defined directly on SubEstimator.
    candidate_names = sorted(
        attr for attr in vars(SubEstimator)
        if not attr.startswith(('_', 'fit'))
    )

    for spec in DELEGATING_METAESTIMATORS:
        checked = [name for name in candidate_names
                   if name not in spec.skip_methods]

        inner = SubEstimator()
        outer = spec.construct(inner)
        for name in checked:
            assert hasattr(inner, name)
            assert hasattr(outer, name), (
                "%s does not have method %r when its delegate does"
                % (spec.name, name))
            # delegation before fit raises a NotFittedError
            assert_raises(NotFittedError, getattr(outer, name),
                          *call_args(spec, name))

        outer.fit(*spec.fit_args)
        for name in checked:
            # smoke test delegation
            getattr(outer, name)(*call_args(spec, name))

        for name in checked:
            hidden_inner = SubEstimator(hidden_method=name)
            hidden_outer = spec.construct(hidden_inner)
            assert not hasattr(hidden_inner, name)
            assert not hasattr(hidden_outer, name), (
                "%s has method %r when its delegate does not"
                % (spec.name, name))