Vehicle-Anti-Theft-Face-Rec.../venv/Lib/site-packages/sklearn/datasets/tests/conftest.py

76 lines
2.2 KiB
Python
Raw Normal View History

2020-11-12 16:05:57 +00:00
""" Network tests are only run, if data is already locally available,
or if download is specifically requested by environment variable."""
import builtins
import functools
from os import environ

import pytest

from sklearn.datasets import fetch_20newsgroups
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.datasets import fetch_california_housing
from sklearn.datasets import fetch_covtype
from sklearn.datasets import fetch_kddcup99
from sklearn.datasets import fetch_olivetti_faces
from sklearn.datasets import fetch_rcv1
def _wrapped_fetch(f, dataset_name):
    """Fetch dataset (download if missing and requested by environment).

    Downloading is enabled only when the environment variable
    ``SKLEARN_SKIP_NETWORK_TESTS`` is set to ``'0'``; in every other case
    the fetcher is called with ``download_if_missing=False``.

    Parameters
    ----------
    f : callable
        Dataset fetcher accepting a ``download_if_missing`` keyword argument.
    dataset_name : str
        Name of the dataset, used in the skip message.

    Returns
    -------
    wrapped : callable
        Function forwarding its arguments to ``f``. Any caller-supplied
        ``download_if_missing`` is overridden by the environment setting.
        If ``f`` raises ``OSError`` (``IOError``), the calling test is
        skipped via ``pytest.skip``; other exceptions propagate unchanged.
    """
    # Read once, when the wrapper is built, not on every call.
    download_if_missing = (
        environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0')

    # Preserve the fetcher's name/docstring so failures point at it.
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        kwargs['download_if_missing'] = download_if_missing
        try:
            return f(*args, **kwargs)
        except OSError:
            # IOError is merely an alias of OSError on Python 3.
            pytest.skip("Download {} to run this test".format(dataset_name))

    return wrapped
@pytest.fixture
def fetch_20newsgroups_fxt():
    """Provide ``fetch_20newsgroups`` wrapped by ``_wrapped_fetch``.

    A test calling the returned function is skipped if it raises IOError.
    """
    fetcher = _wrapped_fetch(fetch_20newsgroups, dataset_name='20newsgroups')
    return fetcher
@pytest.fixture
def fetch_20newsgroups_vectorized_fxt():
    """Provide ``fetch_20newsgroups_vectorized`` wrapped by ``_wrapped_fetch``.

    A test calling the returned function is skipped if it raises IOError.
    """
    name = '20newsgroups_vectorized'
    return _wrapped_fetch(fetch_20newsgroups_vectorized, dataset_name=name)
@pytest.fixture
def fetch_california_housing_fxt():
    """Provide ``fetch_california_housing`` wrapped by ``_wrapped_fetch``.

    A test calling the returned function is skipped if it raises IOError.
    """
    name = 'california_housing'
    fetcher = _wrapped_fetch(fetch_california_housing, dataset_name=name)
    return fetcher
@pytest.fixture
def fetch_covtype_fxt():
    """Provide ``fetch_covtype`` wrapped by ``_wrapped_fetch``.

    A test calling the returned function is skipped if it raises IOError.
    """
    name = 'covtype'
    return _wrapped_fetch(fetch_covtype, dataset_name=name)
@pytest.fixture
def fetch_kddcup99_fxt():
    """Provide ``fetch_kddcup99`` wrapped by ``_wrapped_fetch``.

    A test calling the returned function is skipped if it raises IOError.
    """
    fetcher = _wrapped_fetch(fetch_kddcup99, dataset_name='kddcup99')
    return fetcher
@pytest.fixture
def fetch_olivetti_faces_fxt():
    """Provide ``fetch_olivetti_faces`` wrapped by ``_wrapped_fetch``.

    A test calling the returned function is skipped if it raises IOError.
    """
    name = 'olivetti_faces'
    fetcher = _wrapped_fetch(fetch_olivetti_faces, dataset_name=name)
    return fetcher
@pytest.fixture
def fetch_rcv1_fxt():
    """Provide ``fetch_rcv1`` wrapped by ``_wrapped_fetch``.

    A test calling the returned function is skipped if it raises IOError.
    """
    name = 'rcv1'
    return _wrapped_fetch(fetch_rcv1, dataset_name=name)
@pytest.fixture
def hide_available_pandas(monkeypatch):
    """Pretend pandas was not installed.

    Replaces ``builtins.__import__`` (via ``monkeypatch``) with a version
    that raises ``ImportError`` for ``pandas`` and for any of its
    submodules (e.g. ``pandas.api.types``). Every other import is delegated
    to the original ``__import__``.
    """
    import_orig = builtins.__import__

    def mocked_import(name, *args, **kwargs):
        # ``import pandas.x`` calls ``__import__('pandas.x', ...)``, so the
        # package prefix must be matched as well as the bare name.
        if name == 'pandas' or name.startswith('pandas.'):
            raise ImportError()
        return import_orig(name, *args, **kwargs)

    monkeypatch.setattr(builtins, '__import__', mocked_import)