166 lines
4.7 KiB
Python
166 lines
4.7 KiB
Python
import numpy as np
|
|
|
|
from ..base import BaseEstimator, ClassifierMixin
|
|
from .validation import _num_samples, check_array
|
|
|
|
|
|
class ArraySlicingWrapper:
    """Positional indexer standing in for ``DataFrame.iloc`` on ``MockDataFrame``.

    Indexing this object indexes the wrapped container and wraps whatever
    comes back in a new ``MockDataFrame``.

    Parameters
    ----------
    array : array-like
        The container to index. The result of ``array[key]`` must be
        something ``MockDataFrame`` accepts (it reads ``shape`` and ``ndim``).
    """

    def __init__(self, array):
        self.array = array

    def __getitem__(self, aslice):
        # Any key the wrapped container understands (slice, int array, ...).
        selected = self.array[aslice]
        return MockDataFrame(selected)
|
|
|
|
|
|
class MockDataFrame:
    """Minimal stand-in for a pandas ``DataFrame`` used in tests.

    It exposes ``shape``, ``ndim``, ``values``, ``len()`` and an ``iloc``
    indexer, and converts to a NumPy array through ``__array__``, but it
    deliberately does not support direct indexing (``df[key]``).

    Parameters
    ----------
    array : ndarray
        The wrapped data. It must provide ``shape`` and ``ndim``.
    """

    # have shape and length but don't support indexing.

    def __init__(self, array):
        self.array = array
        self.values = array
        self.shape = array.shape
        self.ndim = array.ndim
        # ugly hack to make iloc work.
        self.iloc = ArraySlicingWrapper(array)

    def __len__(self):
        return len(self.array)

    def __array__(self, dtype=None, copy=None):
        """Return the wrapped data as a NumPy array.

        Parameters
        ----------
        dtype : data-type, default=None
            If given, the data is converted to this dtype.
        copy : bool or None, default=None
            If true, a fresh copy is always returned. Accepted because
            NumPy 2 passes ``copy`` to ``__array__`` and warns when the
            method cannot take it.

        Returns
        -------
        ndarray
            ``self.array`` itself when neither a dtype nor a copy is
            requested, otherwise a converted and/or copied array.
        """
        # Pandas data frames also are array-like: we want to make sure that
        # input validation in cross-validation does not try to call that
        # method.
        if copy:
            return np.array(self.array, dtype=dtype, copy=True)
        if dtype is not None:
            return np.asarray(self.array, dtype=dtype)
        return self.array

    def __eq__(self, other):
        """Element-wise ``==``; ``other`` must expose an ``array`` attribute."""
        return MockDataFrame(self.array == other.array)

    def __ne__(self, other):
        """Element-wise ``!=``; ``other`` must expose an ``array`` attribute."""
        # Mirror __eq__. The former ``not self == other`` took the truthiness
        # of the MockDataFrame returned by __eq__, which falls back to
        # __len__, so any two non-empty frames compared as "not unequal".
        return MockDataFrame(self.array != other.array)
|
|
|
|
|
|
class CheckingClassifier(ClassifierMixin, BaseEstimator):
    """Dummy classifier to test pipelining and meta-estimators.

    Checks some property of X and y in fit / predict.
    This allows testing whether pipelines / cross-validation or metaestimators
    changed the input.

    Parameters
    ----------
    check_y : callable, default=None
        If not None, called as ``check_y(y)`` in ``fit``; a falsy return
        value raises ``AssertionError``.

    check_X : callable, default=None
        If not None, called as ``check_X(X)`` in ``fit`` and
        ``check_X(T)`` in ``predict``; a falsy return value raises
        ``AssertionError``.

    foo_param : int, default=0
        Controls ``score``: it returns 1.0 when ``foo_param > 1`` and 0.0
        otherwise.

    expected_fit_params : sequence of str, default=None
        Names of keyword arguments that must be passed to ``fit``. When set,
        a missing name, or any fit parameter whose length differs from
        ``len(X)``, raises ``AssertionError``.

    Attributes
    ----------
    classes_ : ndarray
        Sorted unique labels of ``y`` seen during ``fit``.

    n_features_in_ : int
        Number of columns of ``X`` seen during ``fit``; 1 when ``X`` is
        one-dimensional or its shape cannot be determined.
    """

    def __init__(self, check_y=None, check_X=None, foo_param=0,
                 expected_fit_params=None):
        self.check_y = check_y
        self.check_X = check_X
        self.foo_param = foo_param
        self.expected_fit_params = expected_fit_params

    def fit(self, X, y, **fit_params):
        """
        Fit classifier

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training vector, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like of shape (n_samples, n_output) or (n_samples,), optional
            Target relative to X for classification or regression;
            None for unsupervised learning.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of the estimator

        Returns
        -------
        self : CheckingClassifier
            The fitted estimator.
        """
        assert len(X) == len(y)
        if self.check_X is not None:
            assert self.check_X(X)
        if self.check_y is not None:
            assert self.check_y(y)
        # n_features_in_ must count columns, not rows (``len(X)`` recorded
        # the number of samples instead).
        try:
            X_shape = np.shape(X)
        except ValueError:
            # e.g. ragged nested lists, which recent NumPy refuses to convert
            X_shape = ()
        self.n_features_in_ = X_shape[1] if len(X_shape) > 1 else 1
        self.classes_ = np.unique(check_array(y, ensure_2d=False,
                                              allow_nd=True))
        if self.expected_fit_params:
            missing = set(self.expected_fit_params) - set(fit_params)
            assert len(missing) == 0, 'Expected fit parameter(s) %s not ' \
                                      'seen.' % list(missing)
            for key, value in fit_params.items():
                assert len(value) == len(X), (
                    'Fit parameter %s has length %d; '
                    'expected %d.'
                    % (key, len(value), len(X)))

        return self

    def predict(self, T):
        """Predict ``classes_[0]`` (the smallest label) for every sample.

        Parameters
        ----------
        T : indexable, length n_samples

        Returns
        -------
        ndarray of shape (n_samples,)
        """
        if self.check_X is not None:
            assert self.check_X(T)
        # ``np.int`` was only a deprecated alias of the builtin ``int`` and
        # no longer exists in NumPy >= 1.24.
        return self.classes_[np.zeros(_num_samples(T), dtype=int)]

    def score(self, X=None, Y=None):
        """
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data, where n_samples is the number of samples and
            n_features is the number of features.

        Y : array-like of shape (n_samples, n_output) or (n_samples,), optional
            Target relative to X for classification or regression;
            None for unsupervised learning.

        Returns
        -------
        score : float
            1.0 if ``foo_param > 1`` else 0.0; ``X`` and ``Y`` are ignored.
        """
        if self.foo_param > 1:
            score = 1.
        else:
            score = 0.
        return score

    def _more_tags(self):
        return {'_skip_test': True, 'X_types': ['1dlabel']}
|
|
|
|
|
|
class NoSampleWeightWrapper(BaseEstimator):
    """Wrap estimator which will not expose `sample_weight`.

    Only ``fit(X, y)``, ``predict`` and ``predict_proba`` are forwarded to
    the wrapped estimator, so this wrapper's ``fit`` has no
    ``sample_weight`` argument even when the wrapped estimator's does.

    Parameters
    ----------
    est : estimator, default=None
        The estimator to wrap.
    """

    def __init__(self, est=None):
        self.est = est

    def fit(self, X, y):
        """Fit the wrapped estimator on ``X`` and ``y``.

        Returns
        -------
        self : NoSampleWeightWrapper
            The wrapper itself, per the scikit-learn ``fit`` convention.
            Returning ``self.est.fit(...)`` handed callers the unwrapped
            estimator, whose ``fit`` may accept ``sample_weight`` again.
        """
        self.est.fit(X, y)
        return self

    def predict(self, X):
        """Delegate to the wrapped estimator's ``predict``."""
        return self.est.predict(X)

    def predict_proba(self, X):
        """Delegate to the wrapped estimator's ``predict_proba``."""
        return self.est.predict_proba(X)

    def _more_tags(self):
        return {'_skip_test': True}
|