Skip to content

Commit 416898b

Browse files
thomasjpfan and ogrisel authored
ENH Adds Column name consistency (#18010)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent c592361 commit 416898b

File tree

16 files changed

+462
-7
lines changed

16 files changed

+462
-7
lines changed

doc/whats_new/v1.0.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ Changelog
134134
- |API| `np.matrix` usage is deprecated in 1.0 and will raise a `TypeError` in
135135
1.2. :pr:`20165` by `Thomas Fan`_.
136136

137+
- |API| All estimators store `feature_names_in_` when fitted on pandas DataFrames.
138+
These feature names are compared to names seen in `non-fit` methods,
139+
`e.g.` `transform`, and will raise a `FutureWarning` if they are not consistent.
140+
These `FutureWarning`s will become `ValueError`s in 1.2.
141+
:pr:`18010` by `Thomas Fan`_.
142+
137143
:mod:`sklearn.base`
138144
...................
139145

sklearn/base.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .utils.validation import _check_y
2525
from .utils.validation import _num_features
2626
from .utils._estimator_html_repr import estimator_html_repr
27+
from .utils.validation import _get_feature_names
2728

2829

2930
def clone(estimator, *, safe=True):
@@ -395,6 +396,92 @@ def _check_n_features(self, X, reset):
395396
f"is expecting {self.n_features_in_} features as input."
396397
)
397398

399+
def _check_feature_names(self, X, *, reset):
    """Set or check the `feature_names_in_` attribute.

    .. versionadded:: 1.0

    Parameters
    ----------
    X : {ndarray, dataframe} of shape (n_samples, n_features)
        The input samples.

    reset : bool
        Whether to reset the `feature_names_in_` attribute.
        If False, the input will be checked for consistency with
        feature names of data provided when reset was last True.

        .. note::
           It is recommended to call `reset=True` in `fit` and in the first
           call to `partial_fit`. All other methods that validate `X`
           should set `reset=False`.

    Warns
    -----
    UserWarning
        If exactly one of `X` and the fitted estimator has feature names.
    FutureWarning
        If both have feature names but they differ in content or order.
    """

    if reset:
        feature_names_in = _get_feature_names(X)
        if feature_names_in is not None:
            self.feature_names_in_ = feature_names_in
        elif hasattr(self, "feature_names_in_"):
            # Refitting on data without feature names: drop the names kept
            # from a previous fit so later checks do not compare against
            # stale values.
            delattr(self, "feature_names_in_")
        return

    fitted_feature_names = getattr(self, "feature_names_in_", None)
    X_feature_names = _get_feature_names(X)

    if fitted_feature_names is None and X_feature_names is None:
        # no feature names seen in fit and in X
        return

    if X_feature_names is not None and fitted_feature_names is None:
        warnings.warn(
            f"X has feature names, but {self.__class__.__name__} was fitted without"
            " feature names"
        )
        return

    if X_feature_names is None and fitted_feature_names is not None:
        warnings.warn(
            "X does not have valid feature names, but"
            f" {self.__class__.__name__} was fitted with feature names"
        )
        return

    # validate the feature names against the `feature_names_in_` attribute;
    # the length test runs first so the element-wise comparison only happens
    # on sequences of equal length.
    if len(fitted_feature_names) != len(X_feature_names) or np.any(
        fitted_feature_names != X_feature_names
    ):
        message = (
            "The feature names should match those that were "
            "passed during fit. Starting version 1.2, an error will be raised.\n"
        )
        fitted_feature_names_set = set(fitted_feature_names)
        X_feature_names_set = set(X_feature_names)

        unexpected_names = sorted(X_feature_names_set - fitted_feature_names_set)
        missing_names = sorted(fitted_feature_names_set - X_feature_names_set)

        def add_names(names):
            # Render at most `max_n_names` names as a bullet list, followed by
            # an ellipsis bullet when the list was truncated.
            max_n_names = 5
            lines = [f"- {name}\n" for name in names[:max_n_names]]
            if len(names) > max_n_names:
                lines.append("- ...\n")
            return "".join(lines)

        if unexpected_names:
            message += "Feature names unseen at fit time:\n"
            message += add_names(unexpected_names)

        if missing_names:
            message += "Feature names seen at fit time, yet now missing:\n"
            message += add_names(missing_names)

        # Same set of names on both sides: the mismatch can only be due to
        # ordering. (Previously `missing_names` was tested twice, so this
        # hint was also emitted when unexpected names were present.)
        if not unexpected_names and not missing_names:
            message += (
                "Feature names must be in the same order as they were in fit.\n"
            )

        warnings.warn(message, FutureWarning)
484+
398485
def _validate_data(
399486
self,
400487
X="no_validation",
@@ -452,6 +539,8 @@ def _validate_data(
452539
The validated input. A tuple is returned if both `X` and `y` are
453540
validated.
454541
"""
542+
self._check_feature_names(X, reset=reset)
543+
455544
if y is None and self._get_tags()["requires_y"]:
456545
raise ValueError(
457546
f"This {self.__class__.__name__} estimator "

sklearn/calibration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ def fit(self, X, y, sample_weight=None):
368368
first_clf = self.calibrated_classifiers_[0].base_estimator
369369
if hasattr(first_clf, "n_features_in_"):
370370
self.n_features_in_ = first_clf.n_features_in_
371+
if hasattr(first_clf, "feature_names_in_"):
372+
self.feature_names_in_ = first_clf.feature_names_in_
371373
return self
372374

373375
def predict_proba(self, X):

sklearn/feature_selection/_from_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ def fit(self, X, y=None, **fit_params):
257257
raise NotFittedError("Since 'prefit=True', call transform directly")
258258
self.estimator_ = clone(self.estimator)
259259
self.estimator_.fit(X, y, **fit_params)
260+
if hasattr(self.estimator_, "feature_names_in_"):
261+
self.feature_names_in_ = self.estimator_.feature_names_in_
260262
return self
261263

262264
@property

sklearn/kernel_approximation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from .base import BaseEstimator
2323
from .base import TransformerMixin
24-
from .utils import check_random_state, as_float_array
24+
from .utils import check_random_state
2525
from .utils.extmath import safe_sparse_dot
2626
from .utils.validation import check_is_fitted
2727
from .metrics.pairwise import pairwise_kernels, KERNEL_PARAMS
@@ -469,9 +469,9 @@ def transform(self, X):
469469
Returns the instance itself.
470470
"""
471471
check_is_fitted(self)
472-
473-
X = as_float_array(X, copy=True)
474-
X = self._validate_data(X, copy=False, reset=False)
472+
X = self._validate_data(
473+
X, copy=True, dtype=[np.float64, np.float32], reset=False
474+
)
475475
if (X <= -self.skewedness).any():
476476
raise ValueError("X may not contain entries smaller than -skewedness.")
477477

sklearn/linear_model/_ransac.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@ def predict(self, X):
556556
Returns predicted values.
557557
"""
558558
check_is_fitted(self)
559+
self._check_feature_names(X, reset=False)
559560

560561
return self.estimator_.predict(X)
561562

@@ -578,6 +579,7 @@ def score(self, X, y):
578579
Score of the prediction.
579580
"""
580581
check_is_fitted(self)
582+
self._check_feature_names(X, reset=False)
581583

582584
return self.estimator_.score(X, y)
583585

sklearn/linear_model/_ridge.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,6 +1983,8 @@ def fit(self, X, y, sample_weight=None):
19831983
self.coef_ = estimator.coef_
19841984
self.intercept_ = estimator.intercept_
19851985
self.n_features_in_ = estimator.n_features_in_
1986+
if hasattr(estimator, "feature_names_in_"):
1987+
self.feature_names_in_ = estimator.feature_names_in_
19861988

19871989
return self
19881990

sklearn/manifold/_isomap.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def _fit_transform(self, X):
172172
)
173173
self.nbrs_.fit(X)
174174
self.n_features_in_ = self.nbrs_.n_features_in_
175+
if hasattr(self.nbrs_, "feature_names_in_"):
176+
self.feature_names_in_ = self.nbrs_.feature_names_in_
175177

176178
self.kernel_pca_ = KernelPCA(
177179
n_components=self.n_components,

sklearn/manifold/_locally_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ def transform(self, X):
768768
"""
769769
check_is_fitted(self)
770770

771-
X = check_array(X)
771+
X = self._validate_data(X, reset=False)
772772
ind = self.nbrs_.kneighbors(
773773
X, n_neighbors=self.n_neighbors, return_distance=False
774774
)

sklearn/neural_network/_rbm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from ..base import BaseEstimator
1717
from ..base import TransformerMixin
18-
from ..utils import check_array
1918
from ..utils import check_random_state
2019
from ..utils import gen_even_slices
2120
from ..utils.extmath import safe_sparse_dot
@@ -333,7 +332,7 @@ def score_samples(self, X):
333332
"""
334333
check_is_fitted(self)
335334

336-
v = check_array(X, accept_sparse="csr")
335+
v = self._validate_data(X, accept_sparse="csr", reset=False)
337336
rng = check_random_state(self.random_state)
338337

339338
# Randomly corrupt one feature in each sample in v.

0 commit comments

Comments
 (0)