Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion imblearn/utils/_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

sklearn_version = parse_version(sklearn.__version__)

if sklearn_version < parse_version("1.2"):
# if sklearn_version < parse_version("1.2"):
if True:
# TODO: remove `if True` when we have clear support for:
# - dataframe

def validate_parameter_constraints(parameter_constraints, params, caller_name):
"""Validate types and values of given parameters.
Expand All @@ -35,6 +38,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
Constraints can be:
- an Interval object, representing a continuous or discrete range of numbers
- the string "array-like"
- the string "dataframe"
- the string "sparse matrix"
- the string "random_state"
- callable
Expand Down Expand Up @@ -115,6 +119,8 @@ def make_constraint(constraint):
return _ArrayLikes()
if isinstance(constraint, str) and constraint == "sparse matrix":
return _SparseMatrices()
if isinstance(constraint, str) and constraint == "dataframe":
return _DataFrames()
if isinstance(constraint, str) and constraint == "random_state":
return _RandomStates()
if constraint is callable:
Expand Down Expand Up @@ -466,6 +472,17 @@ def is_satisfied_by(self, val):
def __str__(self):
return "a sparse matrix"

class _DataFrames(_Constraint):
    """Constraint satisfied by dataframe-like objects."""

    def is_satisfied_by(self, val):
        """Return True when `val` looks like a dataframe.

        The dataframe interchange protocol attribute is probed first; the
        `iloc` attribute is the duck-typing fallback for older pandas
        versions that do not expose the protocol.
        """
        dataframe_markers = ("__dataframe__", "iloc")
        return any(hasattr(val, marker) for marker in dataframe_markers)

    def __str__(self):
        """Return the human-readable name of this constraint."""
        return "a DataFrame"

class _Callables(_Constraint):
"""Constraint representing callables."""

Expand Down Expand Up @@ -845,6 +862,11 @@ def generate_valid_param(constraint):
if isinstance(constraint, _SparseMatrices):
return csr_matrix([[0, 1], [1, 0]])

if isinstance(constraint, _DataFrames):
import pandas as pd

return pd.DataFrame({"a": [1, 2, 3]})

if isinstance(constraint, _RandomStates):
return np.random.RandomState(42)

Expand Down
36 changes: 36 additions & 0 deletions imblearn/utils/tests/test_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_Booleans,
_Callables,
_CVObjects,
_DataFrames,
_InstancesOf,
_IterablesNotString,
_MissingValues,
Expand All @@ -36,6 +37,15 @@
)


def has_pandas():
    """Probe for pandas and build a small sample frame when it is available.

    Returns
    -------
    tuple
        ``(True, frame)`` where ``frame`` is a one-column DataFrame
        ``{"a": [1, 2, 3]}`` when pandas can be imported, otherwise
        ``(False, None)``.
    """
    try:
        import pandas

        # Built inside the ``try`` so that an ImportError raised at this point
        # is reported as "pandas unavailable" as well.
        sample_frame = pandas.DataFrame({"a": [1, 2, 3]})
    except ImportError:
        return False, None
    return True, sample_frame


# Some helpers for the tests
@validate_params({"a": [Real], "b": [Real], "c": [Real], "d": [Real]})
def _func(a, b=0, *args, c, d=0, **kwargs):
Expand Down Expand Up @@ -317,6 +327,12 @@ def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval)
"constraints",
[
[_ArrayLikes()],
pytest.param(
[_DataFrames()],
marks=pytest.mark.skipif(
not has_pandas()[0], reason="Pandas not installed"
),
),
[_InstancesOf(list)],
[_Callables()],
[_NoneConstraint()],
Expand All @@ -342,6 +358,12 @@ def test_generate_invalid_param_val_all_valid(constraints):
"constraint",
[
_ArrayLikes(),
pytest.param(
_DataFrames(),
marks=pytest.mark.skipif(
not has_pandas()[0], reason="Pandas not installed"
),
),
_Callables(),
_InstancesOf(list),
_NoneConstraint(),
Expand Down Expand Up @@ -381,6 +403,13 @@ def test_generate_valid_param(constraint):
(None, None),
("array-like", [[1, 2], [3, 4]]),
("array-like", np.array([[1, 2], [3, 4]])),
pytest.param(
"dataframe",
has_pandas()[1],
marks=pytest.mark.skipif(
not has_pandas()[0], reason="Pandas not installed"
),
),
("sparse matrix", csr_matrix([[1, 2], [3, 4]])),
("random_state", 0),
("random_state", np.random.RandomState(0)),
Expand Down Expand Up @@ -414,6 +443,13 @@ def test_is_satisfied_by(constraint_declaration, value):
(Options(Real, {0.42, 1.23}), Options),
("array-like", _ArrayLikes),
("sparse matrix", _SparseMatrices),
pytest.param(
"dataframe",
_DataFrames,
marks=pytest.mark.skipif(
not has_pandas()[0], reason="Pandas not installed"
),
),
("random_state", _RandomStates),
(None, _NoneConstraint),
(callable, _Callables),
Expand Down