1212from sklearn .utils .multiclass import check_classification_targets
1313
1414from .utils import check_sampling_strategy , check_target_type
15+ from .utils ._param_validation import validate_parameter_constraints
1516from .utils ._validation import ArraysTransformer
1617
1718
@@ -113,7 +114,26 @@ def _fit_resample(self, X, y):
113114 pass
114115
115116
116- class BaseSampler (SamplerMixin ):
117+ class _ParamsValidationMixin :
118+ """Mixin class to validate parameters."""
119+
120+ def _validate_params (self ):
121+ """Validate types and values of constructor parameters.
122+
123+ The expected type and values must be defined in the `_parameter_constraints`
124+ class attribute, which is a dictionary `param_name: list of constraints`. See
125+ the docstring of `validate_parameter_constraints` for a description of the
126+ accepted constraints.
127+ """
128+ if hasattr (self , "_parameter_constraints" ):
129+ validate_parameter_constraints (
130+ self ._parameter_constraints ,
131+ self .get_params (deep = False ),
132+ caller_name = self .__class__ .__name__ ,
133+ )
134+
135+
136+ class BaseSampler (SamplerMixin , _ParamsValidationMixin ):
117137 """Base class for sampling algorithms.
118138
119139 Warning: This class should not be used directly. Use the derive classes
@@ -130,6 +150,52 @@ def _check_X_y(self, X, y, accept_sparse=None):
130150 X , y = self ._validate_data (X , y , reset = True , accept_sparse = accept_sparse )
131151 return X , y , binarize_y
132152
153+ def fit (self , X , y ):
154+ """Check inputs and statistics of the sampler.
155+
156+ You should use ``fit_resample`` in all cases.
157+
158+ Parameters
159+ ----------
160+ X : {array-like, dataframe, sparse matrix} of shape \
161+ (n_samples, n_features)
162+ Data array.
163+
164+ y : array-like of shape (n_samples,)
165+ Target array.
166+
167+ Returns
168+ -------
169+ self : object
170+ Return the instance itself.
171+ """
172+ self ._validate_params ()
173+ return super ().fit (X , y )
174+
175+ def fit_resample (self , X , y ):
176+ """Resample the dataset.
177+
178+ Parameters
179+ ----------
180+ X : {array-like, dataframe, sparse matrix} of shape \
181+ (n_samples, n_features)
182+ Matrix containing the data which have to be sampled.
183+
184+ y : array-like of shape (n_samples,)
185+ Corresponding label for each sample in X.
186+
187+ Returns
188+ -------
189+ X_resampled : {array-like, dataframe, sparse matrix} of shape \
190+ (n_samples_new, n_features)
191+ The array containing the resampled data.
192+
193+ y_resampled : array-like of shape (n_samples_new,)
194+ The corresponding label of `X_resampled`.
195+ """
196+ self ._validate_params ()
197+ return super ().fit_resample (X , y )
198+
133199 def _more_tags (self ):
134200 return {"X_types" : ["2darray" , "sparse" , "dataframe" ]}
135201
@@ -241,6 +307,13 @@ class FunctionSampler(BaseSampler):
241307
242308 _sampling_type = "bypass"
243309
    # Constraints read by `_validate_params`, keyed by constructor parameter
    # name; each value is the list of constraints accepted for that parameter.
    # The string entries are constraint names presumably resolved by
    # `validate_parameter_constraints` (see its docstring to confirm).
    _parameter_constraints: dict = {
        "func": [callable, None],  # a callable, or None
        "accept_sparse": ["boolean"],
        "kw_args": [dict, None],  # a dict, or None
        "validate": ["boolean"],
    }
316+
244317 def __init__ (self , * , func = None , accept_sparse = True , kw_args = None , validate = True ):
245318 super ().__init__ ()
246319 self .func = func
@@ -267,6 +340,7 @@ def fit(self, X, y):
267340 self : object
268341 Return the instance itself.
269342 """
343+ self ._validate_params ()
270344 # we need to overwrite SamplerMixin.fit to bypass the validation
271345 if self .validate :
272346 check_classification_targets (y )
@@ -298,6 +372,7 @@ def fit_resample(self, X, y):
298372 y_resampled : array-like of shape (n_samples_new,)
299373 The corresponding label of `X_resampled`.
300374 """
375+ self ._validate_params ()
301376 arrays_transformer = ArraysTransformer (X , y )
302377
303378 if self .validate :
0 commit comments