Skip to content

Commit ecfaa8c

Browse files
committed
numpy sum
1 parent 9ae63ae commit ecfaa8c

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/test_stats.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,10 @@ def test_sum_benchmark(
9696
except NotImplementedError:
9797
pytest.skip("random_array not implemented for dtype")
9898

99-
stats.sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile
100-
benchmark(stats.sum, arr, axis=axis)
99+
def sum(arr, axis):
100+
if hasattr(arr, sum):
101+
return arr.sum(axis=axis)
102+
return np.sum(arr, axis=axis)
103+
104+
sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile
105+
benchmark(sum, arr, axis=axis)

0 commit comments

Comments
 (0)