Vehicle-Anti-Theft-Face-Rec.../venv/Lib/site-packages/sklearn/tests/test_docstring_parameters.py

258 lines
9.9 KiB
Python

# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
# Raghav RV <rvraghav93@gmail.com>
# License: BSD 3 clause
import inspect
import warnings
import importlib
from pkgutil import walk_packages
from inspect import signature
import numpy as np
import sklearn
from sklearn.utils import IS_PYPY
from sklearn.utils._testing import check_docstring_parameters
from sklearn.utils._testing import _get_func_name
from sklearn.utils._testing import ignore_warnings
from sklearn.utils._testing import all_estimators
from sklearn.utils.estimator_checks import _enforce_estimator_tags_y
from sklearn.utils.estimator_checks import _enforce_estimator_tags_x
from sklearn.utils.deprecation import _is_deprecated
from sklearn.externals._pep562 import Pep562
from sklearn.datasets import make_classification
import pytest
# walk_packages() ignores DeprecationWarnings by itself; FutureWarnings
# have to be silenced explicitly while the package tree is walked.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', FutureWarning)
    # Dotted names of all public sklearn modules: a module is dropped as
    # soon as its name contains "._" (private) or ".tests." (test code).
    PUBLIC_MODULES = {
        module_info[1]
        for module_info in walk_packages(
            prefix='sklearn.',
            # mypy error: Module has no attribute "__path__"
            path=sklearn.__path__)  # type: ignore  # mypy issue #1422
        if "._" not in module_info[1] and ".tests." not in module_info[1]
    }
# functions to ignore args / docstring of
_DOCSTRING_IGNORES = [
'sklearn.utils.deprecation.load_mlcomp',
'sklearn.pipeline.make_pipeline',
'sklearn.pipeline.make_union',
'sklearn.utils.extmath.safe_sparse_dot',
'sklearn.utils._joblib'
]
# Methods where y param should be ignored if y=None by default
_METHODS_IGNORE_NONE_Y = [
'fit',
'score',
'fit_predict',
'fit_transform',
'partial_fit',
'predict'
]
# numpydoc 0.8.0's docscrape tool raises because of collections.abc under
# Python 3.7
@pytest.mark.filterwarnings('ignore::FutureWarning')
@pytest.mark.filterwarnings('ignore::DeprecationWarning')
@pytest.mark.skipif(IS_PYPY, reason='test segfaults on PyPy')
def test_docstring_parameters():
    """Check that documented parameters match the real signatures.

    Every module listed in PUBLIC_MODULES (except ``sklearn.utils.fixes``)
    is imported; the public classes and functions defined in it are passed
    to ``check_docstring_parameters`` and all reported problems are
    collected.

    Raises
    ------
    RuntimeError
        If building the numpydoc ``ClassDoc`` of a class emits a warning.
    AssertionError
        If at least one docstring problem was collected.
    """
    # Skip test if numpydoc is not found
    pytest.importorskip('numpydoc',
                        reason="numpydoc is required to test the docstrings")

    # XXX unreached code as of v0.22
    from numpydoc import docscrape

    errors = []
    for name in PUBLIC_MODULES:
        # We cannot always control these docstrings
        if name == 'sklearn.utils.fixes':
            continue

        with warnings.catch_warnings(record=True):
            module = importlib.import_module(name)

        # Only classes defined in this very module (imported ones excluded)
        own_classes = [
            (cname, cls)
            for cname, cls in inspect.getmembers(module, inspect.isclass)
            if cls.__module__ == name
        ]
        for cname, cls in own_classes:
            if cname in _DOCSTRING_IGNORES or cname.startswith('_'):
                continue
            if inspect.isabstract(cls):
                continue

            with warnings.catch_warnings(record=True) as caught:
                cdoc = docscrape.ClassDoc(cls)
            if caught:
                raise RuntimeError('Error for __init__ of %s in %s:\n%s'
                                   % (cls, name, caught[0]))

            class_errors = []
            cls_init = getattr(cls, '__init__', None)
            if _is_deprecated(cls_init):
                # A deprecated constructor exempts the whole class
                continue
            elif cls_init is not None:
                class_errors.extend(
                    check_docstring_parameters(cls.__init__, cdoc))

            for method_name in cdoc.methods:
                method = getattr(cls, method_name)
                if _is_deprecated(method):
                    continue

                # Now skip docstring test for y when y is None
                # by default for API reason
                param_ignore = None
                if method_name in _METHODS_IGNORE_NONE_Y:
                    params = signature(method).parameters
                    if 'y' in params and params['y'].default is None:
                        param_ignore = ['y']  # ignore y for fit and score

                class_errors.extend(
                    check_docstring_parameters(method, ignore=param_ignore))

            errors.extend(class_errors)

        # Only functions defined in this very module (imported ones excluded)
        own_functions = [
            (fname, func)
            for fname, func in inspect.getmembers(module, inspect.isfunction)
            if func.__module__ == name
        ]
        for fname, func in own_functions:
            # Don't test private methods / functions
            if fname.startswith('_'):
                continue
            if fname == "configuration" and name.endswith("setup"):
                continue

            name_ = _get_func_name(func)
            if (any(d in name_ for d in _DOCSTRING_IGNORES) or
                    _is_deprecated(func)):
                continue
            errors.extend(check_docstring_parameters(func))

    if errors:
        raise AssertionError("Docstring Error:\n" + '\n'.join(errors))
@ignore_warnings(category=FutureWarning)
def test_tabs():
    """Check that no sklearn module source contains a tab character.

    Every module found by ``walk_packages`` under ``sklearn.`` is imported
    and its source is read with ``inspect.getsource``. Modules whose source
    cannot be read are skipped.

    Raises
    ------
    AssertionError
        If the source of a module contains a tab; the message names the
        offending module.
    """
    for _, modname, _ in walk_packages(sklearn.__path__,
                                       prefix='sklearn.'):

        # These two modules are skipped on PyPy
        # NOTE(review): presumably because they are not importable there
        # (original comment: "because we don't import") -- confirm.
        if IS_PYPY and ('_svmlight_format_io' in modname or
                        'feature_extraction._hashing_fast' in modname):
            continue

        mod = importlib.import_module(modname)

        # TODO: Remove when minimum python version is 3.7
        # unwrap to get module because Pep562 backport wraps the original
        # module
        if isinstance(mod, Pep562):
            mod = mod._module

        try:
            source = inspect.getsource(mod)
        except IOError:  # user probably should have run "make clean"
            continue

        # The message must be ONE string: adjacent literals are joined at
        # compile time, so ``%`` formats the whole sentence. The former
        # tuple form applied ``%`` to a fragment without any ``%s`` and
        # raised TypeError instead of reporting the module.
        assert '\t' not in source, (
            '"%s" has tabs, please remove them '
            'or add it to the ignore list' % modname)
@pytest.mark.parametrize('name, Estimator',
                         all_estimators())
def test_fit_docstring_attributes(name, Estimator):
    """Check that documented attributes and fitted attributes agree.

    The estimator is built with its default parameters (a few estimators
    get one parameter overridden), fitted on a small classification
    dataset, and then checked in two directions:

    * every attribute listed in the "Attributes" docstring section must
      exist after ``fit``, unless its description contains ``'only '``;
    * every public attribute ending with ``_`` that exists after ``fit``
      must be documented.

    Parameters
    ----------
    name : str
        Name supplied by ``all_estimators()``; unused in the body.
    Estimator : type
        Estimator class under test.
    """
    pytest.importorskip('numpydoc')
    from numpydoc import docscrape

    documented = docscrape.ClassDoc(Estimator)['Attributes']
    est_name = Estimator.__name__

    not_fittable = {
        'ClassifierChain', 'ColumnTransformer', 'CountVectorizer',
        'DictVectorizer', 'FeatureUnion', 'GaussianRandomProjection',
        'GridSearchCV', 'MultiOutputClassifier', 'MultiOutputRegressor',
        'NoSampleWeightWrapper', 'OneVsOneClassifier',
        'OneVsRestClassifier', 'OutputCodeClassifier', 'Pipeline',
        'RFE', 'RFECV', 'RandomizedSearchCV', 'RegressorChain',
        'SelectFromModel', 'SparseCoder', 'SparseRandomProjection',
        'SpectralBiclustering', 'StackingClassifier',
        'StackingRegressor', 'TfidfVectorizer', 'VotingClassifier',
        'VotingRegressor',
    }
    if est_name in not_fittable or est_name.startswith('_'):
        pytest.skip("Estimator cannot be fit easily to test fit attributes")

    est = Estimator()

    # Parameters overridden on the default-constructed estimator
    param_overrides = {
        'SelectKBest': {'k': 2},
        'DummyClassifier': {'strategy': "stratified"},
        # TO BE REMOVED for v0.25 (avoid FutureWarning)
        'AffinityPropagation': {'random_state': 63},
    }
    for param, value in param_overrides.get(est_name, {}).items():
        setattr(est, param, value)

    X, y = make_classification(n_samples=20, n_features=3,
                               n_redundant=0, n_classes=2,
                               random_state=2)
    y = _enforce_estimator_tags_y(est, y)
    X = _enforce_estimator_tags_x(est, X)

    # Estimators tagged as label-only are fitted on y instead of X
    if '1dlabels' in est._get_tags()['X_types']:
        est.fit(y)
    elif '2dlabels' in est._get_tags()['X_types']:
        est.fit(np.c_[y, y])
    else:
        est.fit(X, y)

    skipped_attributes = {'n_features_in_'}

    for attr in documented:
        if attr.name in skipped_attributes:
            continue
        description = ' '.join(attr.desc).lower()
        # As certain attributes are present "only" if a certain parameter is
        # provided, this checks if the word "only" is present in the attribute
        # description, and if not the attribute is required to be present.
        if 'only ' not in description:
            assert hasattr(est, attr.name)

    under_documented = {
        'BayesianRidge', 'Birch', 'CCA', 'CategoricalNB', 'ElasticNet',
        'ElasticNetCV', 'GaussianProcessClassifier',
        'GradientBoostingRegressor', 'HistGradientBoostingClassifier',
        'HistGradientBoostingRegressor', 'IsolationForest',
        'KNeighborsClassifier', 'KNeighborsRegressor',
        'KNeighborsTransformer', 'KernelCenterer', 'KernelDensity',
        'LarsCV', 'Lasso', 'LassoLarsCV', 'LassoLarsIC',
        'LatentDirichletAllocation', 'LocalOutlierFactor', 'MDS',
        'MiniBatchKMeans', 'MLPClassifier', 'MLPRegressor',
        'MultiTaskElasticNet', 'MultiTaskElasticNetCV',
        'MultiTaskLasso', 'MultiTaskLassoCV', 'NearestNeighbors',
        'NuSVR', 'OneClassSVM', 'OrthogonalMatchingPursuit',
        'PLSCanonical', 'PLSRegression', 'PLSSVD',
        'PassiveAggressiveClassifier', 'Perceptron', 'RBFSampler',
        'RadiusNeighborsClassifier', 'RadiusNeighborsRegressor',
        'RadiusNeighborsTransformer', 'RandomTreesEmbedding', 'SVR',
        'SkewedChi2Sampler',
    }
    if est_name in under_documented:
        pytest.xfail(
            reason="Classifier has too many undocumented attributes.")

    # Public attributes with a trailing underscore found on the estimator
    fitted_names = {key for key in est.__dict__
                    if key.endswith('_') and not key.startswith('_')}
    documented_names = {attr.name for attr in documented}
    undocumented_attrs = fitted_names - documented_names - skipped_attributes
    assert not undocumented_attrs,\
        "Undocumented attributes: {}".format(undocumented_attrs)