From b439c4f49aec278d60cb8beff0b72677ffdf0fae Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 20 Dec 2024 13:18:04 +0100 Subject: [PATCH] MAINT use public import for metadata routing --- imblearn/base.py | 7 +++++++ imblearn/pipeline.py | 14 ++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/imblearn/base.py b/imblearn/base.py index e837ee8da..39141029a 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -9,12 +9,19 @@ import numpy as np from sklearn.base import BaseEstimator, OneToOneFeatureMixin from sklearn.preprocessing import label_binarize +from sklearn.utils._metadata_requests import METHODS from sklearn.utils.multiclass import check_classification_targets from .utils import check_sampling_strategy, check_target_type from .utils._sklearn_compat import _fit_context, get_tags, validate_data from .utils._validation import ArraysTransformer +if "fit_predict" not in METHODS: + METHODS.append("fit_predict") +if "fit_transform" not in METHODS: + METHODS.append("fit_transform") +METHODS.append("fit_resample") + class SamplerMixin(metaclass=ABCMeta): """Mixin class for samplers with abstract method. diff --git a/imblearn/pipeline.py b/imblearn/pipeline.py index 0ef7288b8..b82b9a543 100644 --- a/imblearn/pipeline.py +++ b/imblearn/pipeline.py @@ -21,18 +21,18 @@ from sklearn.base import clone from sklearn.exceptions import NotFittedError from sklearn.utils import Bunch -from sklearn.utils._metadata_requests import ( - METHODS, +from sklearn.utils._param_validation import HasMethods +from sklearn.utils.fixes import parse_version +from sklearn.utils.metadata_routing import ( MetadataRouter, MethodMapping, _routing_enabled, get_routing_for_object, ) -from sklearn.utils._param_validation import HasMethods -from sklearn.utils.fixes import parse_version from sklearn.utils.metaestimators import available_if from sklearn.utils.validation import check_is_fitted, check_memory +from .base import METHODS from .utils._sklearn_compat import ( _fit_context, _print_elapsed_time, @@ -43,12 +43,6 @@ validate_params, ) -if "fit_predict" not in METHODS: - METHODS.append("fit_predict") -if "fit_transform" not in METHODS: - METHODS.append("fit_transform") -METHODS.append("fit_resample") - __all__ = ["Pipeline", "make_pipeline"]