129 lines
3.9 KiB
Python
129 lines
3.9 KiB
Python
|
import pytest
|
||
|
import types
|
||
|
import numpy as np
|
||
|
import warnings
|
||
|
|
||
|
from sklearn.dummy import DummyClassifier
|
||
|
from sklearn.utils import all_estimators
|
||
|
from sklearn.utils.estimator_checks import choose_check_classifiers_labels
|
||
|
from sklearn.utils.estimator_checks import NotAnArray
|
||
|
from sklearn.utils.estimator_checks import enforce_estimator_tags_y
|
||
|
from sklearn.utils.estimator_checks import is_public_parameter
|
||
|
from sklearn.utils.estimator_checks import pairwise_estimator_convert_X
|
||
|
from sklearn.utils.estimator_checks import set_checking_parameters
|
||
|
from sklearn.utils.optimize import newton_cg
|
||
|
from sklearn.utils.random import random_choice_csc
|
||
|
from sklearn.utils import safe_indexing
|
||
|
|
||
|
|
||
|
# This file tests deprecated utilities (each test is scheduled for removal in 0.24)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_choose_check_classifiers_labels_deprecated():
    """Calling choose_check_classifiers_labels emits a FutureWarning.

    The warning message must mention removal in version 0.24.
    """
    expected_msg = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_msg):
        choose_check_classifiers_labels(None, None, None)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_enforce_estimator_tags_y():
    """enforce_estimator_tags_y warns that it will be removed in 0.24."""
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        # Arguments are built inside the context, in the same order as a
        # single inline call would evaluate them.
        estimator = DummyClassifier()
        labels = np.array([0, 1])
        enforce_estimator_tags_y(estimator, labels)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_notanarray():
    """Instantiating NotAnArray emits the 0.24-removal FutureWarning."""
    removal_msg = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=removal_msg):
        NotAnArray([1, 2])
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_is_public_parameter():
    """is_public_parameter emits a FutureWarning announcing its removal."""
    param_name = 'hello'
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        is_public_parameter(param_name)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_pairwise_estimator_convert_X():
    """pairwise_estimator_convert_X emits the 0.24-removal FutureWarning."""
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        X = [[1, 2]]
        estimator = DummyClassifier()
        pairwise_estimator_convert_X(X, estimator)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_set_checking_parameters():
    """set_checking_parameters emits the 0.24-removal FutureWarning."""
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        clf = DummyClassifier()
        set_checking_parameters(clf)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_newton_cg():
    """newton_cg emits the 0.24-removal FutureWarning.

    The solver is driven on the quadratic objective 0.5 * ||M x||^2, where
    M is a seeded random 10x10 matrix, starting from a vector of ones.
    """
    random_state = np.random.RandomState(0)
    mat = random_state.normal(size=(10, 10))
    start = np.ones(10)

    def func(x):
        # 0.5 * ||M x||^2
        projected = mat.dot(x)
        return .5 * (projected).dot(projected)

    def grad(x):
        # Gradient of the objective: M^T M x
        return mat.T.dot(mat.dot(x))

    def hess_product(vec):
        # Hessian-vector product: M^T M vec
        return mat.T.dot(mat.dot(vec))

    def grad_hess(x):
        return grad(x), hess_product

    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        newton_cg(grad_hess, func, grad, start)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_random_choice_csc():
    """random_choice_csc emits the 0.24-removal FutureWarning."""
    n_samples = 10
    classes = [[2]]
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        random_choice_csc(n_samples, classes)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_safe_indexing():
    """The deprecated sklearn.utils.safe_indexing emits a FutureWarning."""
    data, index = [1, 2], 0
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        safe_indexing(data, index)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_partial_dependence_no_shadowing():
    """``sklearn.inspection.partial_dependence`` must stay a function.

    This must hold even after the deprecated submodule of the same name
    has been imported. Non-regression test for
    https://github.com/scikit-learn/scikit-learn/issues/15842
    """
    # Import from the deprecated submodule, silencing its FutureWarning.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        from sklearn.inspection.partial_dependence import partial_dependence as _  # noqa

    # all_estimators() recursively imports every submodule, deprecated
    # ones included.
    all_estimators()

    from sklearn.inspection import partial_dependence as public_func
    assert isinstance(public_func, types.FunctionType)
|
||
|
|
||
|
|
||
|
# TODO: remove in 0.24
|
||
|
def test_dict_learning_no_shadowing():
    """``sklearn.decomposition.dict_learning`` must stay a function.

    This must hold even after the deprecated submodule of the same name
    has been imported. Non-regression test for
    https://github.com/scikit-learn/scikit-learn/issues/15842
    """
    # Import from the deprecated submodule, silencing its FutureWarning.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        from sklearn.decomposition.dict_learning import dict_learning as _  # noqa

    # all_estimators() recursively imports every submodule, deprecated
    # ones included.
    all_estimators()

    from sklearn.decomposition import dict_learning as public_func
    assert isinstance(public_func, types.FunctionType)
|