Skip to content

Commit d300769

Browse files
timhoffmmeeseeksmachine
authored andcommitted
Backport PR matplotlib#30714: FIX: Gracefully handle numpy arrays as input to check_in_list()
1 parent 799bc95 commit d300769

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

lib/matplotlib/_api/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def check_in_list(values, /, *, _print_supported_values=True, **kwargs):
106106
----------
107107
values : iterable
108108
Sequence of values to check on.
109+
110+
Note: All values must support == comparisons.
111+
This means in particular the entries must not be numpy arrays.
109112
_print_supported_values : bool, default: True
110113
Whether to print *values* when raising ValueError.
111114
**kwargs : dict
@@ -123,7 +126,18 @@ def check_in_list(values, /, *, _print_supported_values=True, **kwargs):
123126
if not kwargs:
124127
raise TypeError("No argument to check!")
125128
for key, val in kwargs.items():
126-
if val not in values:
129+
try:
130+
exists = val in values
131+
except ValueError:
132+
# `in` internally uses `val == values[i]`. There are some objects
133+
# that do not support == to arbitrary other objects, in particular
134+
# numpy arrays.
135+
# Since such objects are not allowed in values, we can gracefully
136+
# handle the case that val (typically provided by users) is of such
137+
# type and directly state it's not in the list instead of letting
138+
# the individual `val == values[i]` ValueError surface.
139+
exists = False
140+
if not exists:
127141
msg = f"{val!r} is not a valid value for {key}"
128142
if _print_supported_values:
129143
msg += f"; supported values are {', '.join(map(repr, values))}"

lib/matplotlib/tests/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,8 @@ def f() -> None:
150150
def test_empty_check_in_list() -> None:
151151
with pytest.raises(TypeError, match="No argument to check!"):
152152
_api.check_in_list(["a"])
153+
154+
155+
def test_check_in_list_numpy() -> None:
156+
with pytest.raises(ValueError, match=r"array\(5\) is not a valid value"):
157+
_api.check_in_list(['a', 'b'], value=np.array(5))

0 commit comments

Comments
 (0)