Skip to content

Commit 5c9d25b

Browse files
authored
allow customizing the expected array type (#38)
* convert the `array_type` property to a static method This will allow choosing the return type depending on the operation. * call `array_type` with `op` * change the numpy tests, as well
1 parent 517c318 commit 5c9d25b

File tree

5 files changed

+15
-12
lines changed

5 files changed

+15
-12
lines changed

xarray_array_testing/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@ class DuckArrayTestMixin(ABC):
1212
def xp() -> ModuleType:
1313
pass
1414

15-
@property
16-
@abc.abstractmethod
17-
def array_type(self) -> type[duckarray]:
15+
@staticmethod
16+
def array_type(op: str) -> type[duckarray]:
1817
pass
1918

2019
@staticmethod

xarray_array_testing/creation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ class CreationTests(DuckArrayTestMixin):
1010
def test_create_variable(self, data):
1111
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
1212

13-
assert isinstance(variable.data, self.array_type)
13+
assert isinstance(variable.data, self.array_type("__init__"))

xarray_array_testing/indexing.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def test_variable_isel_orthogonal(self, data):
9595
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
9696
expected = variable.data[*raw_indexers.values()]
9797

98-
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
98+
assert isinstance(
99+
actual, self.array_type("orthogonal_indexing")
100+
), f"wrong type: {type(actual)}"
99101
self.assert_equal(actual, expected)
100102

101103
@given(st.data())
@@ -109,5 +111,7 @@ def test_variable_isel_vectorized(self, data):
109111
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
110112
expected = variable.data[*raw_indexers.values()]
111113

112-
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
114+
assert isinstance(
115+
actual, self.array_type("vectorized_indexing")
116+
), f"wrong type: {type(actual)}"
113117
self.assert_equal(actual, expected)

xarray_array_testing/reduction.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_variable_numerical_reduce(self, op, data):
2525
# compute using xp.<OP>(array)
2626
expected = getattr(self.xp, op)(variable.data)
2727

28-
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
28+
assert isinstance(actual, self.array_type(op)), f"wrong type: {type(actual)}"
2929
self.assert_equal(actual, expected)
3030

3131
@pytest.mark.parametrize("op", ["all", "any"])
@@ -39,7 +39,7 @@ def test_variable_boolean_reduce(self, op, data):
3939
# compute using xp.<OP>(array)
4040
expected = getattr(self.xp, op)(variable.data)
4141

42-
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
42+
assert isinstance(actual, self.array_type(op)), f"wrong type: {type(actual)}"
4343
self.assert_equal(actual, expected)
4444

4545
@pytest.mark.parametrize("op", ["max", "min"])
@@ -53,7 +53,7 @@ def test_variable_order_reduce(self, op, data):
5353
# compute using xp.<OP>(array)
5454
expected = getattr(self.xp, op)(variable.data)
5555

56-
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
56+
assert isinstance(actual, self.array_type(op)), f"wrong type: {type(actual)}"
5757
self.assert_equal(actual, expected)
5858

5959
@pytest.mark.parametrize("op", ["argmax", "argmin"])
@@ -96,5 +96,5 @@ def test_variable_cumulative_reduce(self, op, data):
9696
for axis in range(variable.ndim):
9797
expected = getattr(self.xp, array_api_names[op])(expected, axis=axis)
9898

99-
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
99+
assert isinstance(actual, self.array_type(op)), f"wrong type: {type(actual)}"
100100
self.assert_equal(actual, expected)

xarray_array_testing/tests/test_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ class NumpyTestMixin(DuckArrayTestMixin):
1818
def xp(self) -> ModuleType:
1919
return np
2020

21-
@property
22-
def array_type(self) -> type[np.ndarray]:
21+
@staticmethod
22+
def array_type(op: str) -> type[np.ndarray]:
2323
return np.ndarray
2424

2525
@staticmethod

0 commit comments

Comments
 (0)