76 lines
2.2 KiB
Python
76 lines
2.2 KiB
Python
|
""" Network tests are only run, if data is already locally available,
|
||
|
or if download is specifically requested by environment variable."""
|
||
|
import builtins
|
||
|
from os import environ
|
||
|
import pytest
|
||
|
from sklearn.datasets import fetch_20newsgroups
|
||
|
from sklearn.datasets import fetch_20newsgroups_vectorized
|
||
|
from sklearn.datasets import fetch_california_housing
|
||
|
from sklearn.datasets import fetch_covtype
|
||
|
from sklearn.datasets import fetch_kddcup99
|
||
|
from sklearn.datasets import fetch_olivetti_faces
|
||
|
from sklearn.datasets import fetch_rcv1
|
||
|
|
||
|
|
||
|
def _wrapped_fetch(f, dataset_name):
    """Wrap a dataset fetcher so that a failed fetch skips the test.

    Parameters
    ----------
    f : callable
        The fetch function to wrap. It is called with the caller's
        arguments plus a ``download_if_missing`` keyword.
    dataset_name : str
        Name inserted into the skip message.

    Returns
    -------
    wrapped : callable
        A function with the metadata of ``f`` (name, docstring) that
        forwards to ``f`` and calls ``pytest.skip`` when ``f`` raises
        ``IOError``.

    Notes
    -----
    ``download_if_missing`` is True only when the environment variable
    ``SKLEARN_SKIP_NETWORK_TESTS`` equals ``'0'``; it defaults to ``'1'``
    (no download). The variable is read once, when the wrapper is built.
    """
    # Imported locally so this helper does not rely on a module-level import.
    from functools import wraps

    download_if_missing = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0'

    @wraps(f)  # keep the fetcher's __name__/__doc__ on the wrapper
    def wrapped(*args, **kwargs):
        # Always override the caller's value with the environment's decision.
        kwargs['download_if_missing'] = download_if_missing
        try:
            return f(*args, **kwargs)
        except IOError:
            pytest.skip("Download {} to run this test".format(dataset_name))

    return wrapped
|
||
|
|
||
|
|
||
|
@pytest.fixture
def fetch_20newsgroups_fxt():
    """Return ``fetch_20newsgroups`` wrapped by ``_wrapped_fetch``.

    The wrapper is labelled '20newsgroups', the name ``_wrapped_fetch``
    puts in its skip message.
    """
    fetcher = _wrapped_fetch(fetch_20newsgroups, '20newsgroups')
    return fetcher
|
||
|
|
||
|
|
||
|
@pytest.fixture
def fetch_20newsgroups_vectorized_fxt():
    """Return ``fetch_20newsgroups_vectorized`` wrapped by ``_wrapped_fetch``.

    The wrapper is labelled '20newsgroups_vectorized', the name
    ``_wrapped_fetch`` puts in its skip message.
    """
    fetcher = _wrapped_fetch(
        fetch_20newsgroups_vectorized, '20newsgroups_vectorized'
    )
    return fetcher
|
||
|
|
||
|
|
||
|
@pytest.fixture
def fetch_california_housing_fxt():
    """Return ``fetch_california_housing`` wrapped by ``_wrapped_fetch``.

    The wrapper is labelled 'california_housing', the name
    ``_wrapped_fetch`` puts in its skip message.
    """
    fetcher = _wrapped_fetch(fetch_california_housing, 'california_housing')
    return fetcher
|
||
|
|
||
|
|
||
|
@pytest.fixture
def fetch_covtype_fxt():
    """Return ``fetch_covtype`` wrapped by ``_wrapped_fetch``.

    The wrapper is labelled 'covtype', the name ``_wrapped_fetch`` puts in
    its skip message.
    """
    fetcher = _wrapped_fetch(fetch_covtype, 'covtype')
    return fetcher
|
||
|
|
||
|
|
||
|
@pytest.fixture
def fetch_kddcup99_fxt():
    """Return ``fetch_kddcup99`` wrapped by ``_wrapped_fetch``.

    The wrapper is labelled 'kddcup99', the name ``_wrapped_fetch`` puts in
    its skip message.
    """
    fetcher = _wrapped_fetch(fetch_kddcup99, 'kddcup99')
    return fetcher
|
||
|
|
||
|
|
||
|
@pytest.fixture
def fetch_olivetti_faces_fxt():
    """Return ``fetch_olivetti_faces`` wrapped by ``_wrapped_fetch``.

    The wrapper is labelled 'olivetti_faces', the name ``_wrapped_fetch``
    puts in its skip message.
    """
    fetcher = _wrapped_fetch(fetch_olivetti_faces, 'olivetti_faces')
    return fetcher
|
||
|
|
||
|
|
||
|
@pytest.fixture
def fetch_rcv1_fxt():
    """Return ``fetch_rcv1`` wrapped by ``_wrapped_fetch``.

    The wrapper is labelled 'rcv1', the name ``_wrapped_fetch`` puts in its
    skip message.
    """
    fetcher = _wrapped_fetch(fetch_rcv1, 'rcv1')
    return fetcher
|
||
|
|
||
|
|
||
|
@pytest.fixture
def hide_available_pandas(monkeypatch):
    """Make importing ``pandas`` fail, as if it were not installed.

    ``builtins.__import__`` is replaced, through ``monkeypatch``, by a
    function that raises ``ImportError`` when the requested module name is
    exactly ``'pandas'`` and delegates every other import to the original
    ``__import__``.
    """
    real_import = builtins.__import__

    def import_without_pandas(name, *args, **kwargs):
        # Only the exact top-level name is intercepted; any other name
        # (including dotted ones) goes to the real import machinery.
        if name != 'pandas':
            return real_import(name, *args, **kwargs)
        raise ImportError()

    monkeypatch.setattr(builtins, '__import__', import_without_pandas)
|