|
24 | 24 | from .utils.validation import _check_y
|
25 | 25 | from .utils.validation import _num_features
|
26 | 26 | from .utils._estimator_html_repr import estimator_html_repr
|
| 27 | +from .utils.validation import _get_feature_names |
27 | 28 |
|
28 | 29 |
|
29 | 30 | def clone(estimator, *, safe=True):
|
@@ -395,6 +396,92 @@ def _check_n_features(self, X, reset):
|
395 | 396 | f"is expecting {self.n_features_in_} features as input."
|
396 | 397 | )
|
397 | 398 |
|
def _check_feature_names(self, X, *, reset):
    """Set or check the `feature_names_in_` attribute.

    .. versionadded:: 1.0

    Parameters
    ----------
    X : {ndarray, dataframe} of shape (n_samples, n_features)
        The input samples.

    reset : bool
        Whether to reset the `feature_names_in_` attribute.
        If True, `feature_names_in_` is set from `X`, or removed when `X`
        provides no feature names.
        If False, the input will be checked for consistency with
        feature names of data provided when reset was last True.

        .. note::
           It is recommended to call `reset=True` in `fit` and in the first
           call to `partial_fit`. All other methods that validate `X`
           should set `reset=False`.

    Warns
    -----
    UserWarning
        If exactly one of `X` and the fitted estimator has feature names.
    FutureWarning
        If both have feature names but they differ in content or order.
    """
    if reset:
        feature_names_in = _get_feature_names(X)
        if feature_names_in is not None:
            self.feature_names_in_ = feature_names_in
        elif hasattr(self, "feature_names_in_"):
            # Drop names left over from a previous fit; otherwise later
            # `reset=False` calls would be validated against stale names.
            delattr(self, "feature_names_in_")
        return

    fitted_feature_names = getattr(self, "feature_names_in_", None)
    X_feature_names = _get_feature_names(X)

    if fitted_feature_names is None and X_feature_names is None:
        # no feature names seen in fit and in X
        return

    if X_feature_names is not None and fitted_feature_names is None:
        warnings.warn(
            f"X has feature names, but {self.__class__.__name__} was fitted without"
            " feature names"
        )
        return

    if X_feature_names is None and fitted_feature_names is not None:
        warnings.warn(
            "X does not have valid feature names, but"
            f" {self.__class__.__name__} was fitted with feature names"
        )
        return

    # validate the feature names against the `feature_names_in_` attribute;
    # the length test runs first so the element-wise comparison is only
    # evaluated on sequences of equal length
    if len(fitted_feature_names) != len(X_feature_names) or np.any(
        fitted_feature_names != X_feature_names
    ):
        message = (
            "The feature names should match those that were "
            "passed during fit. Starting version 1.2, an error will be raised.\n"
        )
        fitted_feature_names_set = set(fitted_feature_names)
        X_feature_names_set = set(X_feature_names)

        unexpected_names = sorted(X_feature_names_set - fitted_feature_names_set)
        missing_names = sorted(fitted_feature_names_set - X_feature_names_set)

        def add_names(names):
            # Render at most `max_n_names` bullet lines, then an ellipsis
            # bullet when the list was truncated.
            max_n_names = 5
            lines = [f"- {name}\n" for name in names[:max_n_names]]
            if len(names) > max_n_names:
                lines.append("- ...\n")
            return "".join(lines)

        if unexpected_names:
            message += "Feature names unseen at fit time:\n"
            message += add_names(unexpected_names)

        if missing_names:
            message += "Feature names seen at fit time, yet now missing:\n"
            message += add_names(missing_names)

        # Only blame ordering when both sides hold the same set of names.
        if not unexpected_names and not missing_names:
            message += (
                "Feature names must be in the same order as they were in fit.\n"
            )

        warnings.warn(message, FutureWarning)
| 484 | + |
398 | 485 | def _validate_data(
|
399 | 486 | self,
|
400 | 487 | X="no_validation",
|
@@ -452,6 +539,8 @@ def _validate_data(
|
452 | 539 | The validated input. A tuple is returned if both `X` and `y` are
|
453 | 540 | validated.
|
454 | 541 | """
|
| 542 | + self._check_feature_names(X, reset=reset) |
| 543 | + |
455 | 544 | if y is None and self._get_tags()["requires_y"]:
|
456 | 545 | raise ValueError(
|
457 | 546 | f"This {self.__class__.__name__} estimator "
|
|
0 commit comments