43 lines
1.7 KiB
Python
43 lines
1.7 KiB
Python
"""Shared checks for functionality common to the dataset loaders."""
|
|
import pytest
|
|
import numpy as np
|
|
|
|
|
|
def check_pandas_dependency_message(fetch_func):
    """Check the error raised by ``fetch_func(as_frame=True)`` without pandas.

    The check is skipped when pandas can be imported, since the
    missing-dependency code path cannot be exercised in that environment.

    Parameters
    ----------
    fetch_func : callable
        Dataset fetcher accepting an ``as_frame`` keyword argument. Its
        ``__name__`` is interpolated into the expected error message.

    Notes
    -----
    The check fails (through ``pytest.raises``) unless calling
    ``fetch_func(as_frame=True)`` raises an ``ImportError`` whose message
    matches ``"<fetch_func.__name__> with as_frame=True requires pandas"``.
    """
    try:
        import pandas  # noqa
    except ImportError:
        # Check that pandas is imported lazily and that an informative error
        # message is raised when pandas is missing:
        expected_msg = ('{} with as_frame=True requires pandas'
                        .format(fetch_func.__name__))
        with pytest.raises(ImportError, match=expected_msg):
            fetch_func(as_frame=True)
    else:
        # Only the import is guarded by the ``try``; skipping from the
        # success branch keeps the skip outcome out of the guarded body.
        pytest.skip("This test requires pandas to be not installed")
|
|
|
|
|
|
def check_return_X_y(bunch, fetch_func_partial):
    """Check that ``return_X_y=True`` yields an ``(X, y)`` tuple.

    Parameters
    ----------
    bunch : Bunch-like
        Reference result of the fetcher. Its ``data`` and ``target``
        attributes provide the expected shapes.
    fetch_func_partial : callable
        Fetcher, with any other arguments already bound, that accepts a
        ``return_X_y`` keyword argument.
    """
    pair = fetch_func_partial(return_X_y=True)
    assert isinstance(pair, tuple)

    # Compare each element against the reference in order: X first, then y.
    features = pair[0]
    assert features.shape == bunch.data.shape
    labels = pair[1]
    assert labels.shape == bunch.target.shape
|
|
|
|
|
|
def check_as_frame(bunch, fetch_func_partial,
                   expected_data_dtype=None, expected_target_dtype=None):
    """Check the pandas containers returned by a fetcher with ``as_frame=True``.

    Skipped when pandas is not installed.

    Parameters
    ----------
    bunch : Bunch-like
        Reference result of the fetcher. Its ``data`` and ``target``
        attributes provide the expected shapes.
    fetch_func_partial : callable
        Fetcher, with any other arguments already bound, that accepts an
        ``as_frame`` keyword argument.
    expected_data_dtype : optional, default=None
        When not None, compared element-wise against ``data.dtypes``.
        Every comparison must be true.
    expected_target_dtype : optional, default=None
        When not None, compared against ``target.dtypes``. Every comparison
        must be true.
    """
    pd = pytest.importorskip('pandas')
    result = fetch_func_partial(as_frame=True)

    # The combined frame must be present and be a DataFrame.
    assert hasattr(result, 'frame')
    assert isinstance(result.frame, pd.DataFrame)

    # Features: a DataFrame with exactly the reference shape.
    features = result.data
    assert isinstance(features, pd.DataFrame)
    assert features.shape == bunch.data.shape

    # Target: a DataFrame when it has more than one dimension, otherwise a
    # Series. Only the number of rows is compared with the reference.
    labels = result.target
    container = pd.DataFrame if labels.ndim > 1 else pd.Series
    assert isinstance(labels, container)
    assert labels.shape[0] == bunch.target.shape[0]

    if expected_data_dtype is not None:
        assert np.all(features.dtypes == expected_data_dtype)
    if expected_target_dtype is not None:
        assert np.all(labels.dtypes == expected_target_dtype)
|