Skip to content

Commit 9df1383

Browse files
Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags (#2484)
* Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags * Add trivially-contiguous arrays to the tests
1 parent f12ec00 commit 9df1383

File tree

3 files changed

+86
-1
lines changed

3 files changed

+86
-1
lines changed

include/pybind11/numpy.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,8 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
934934
static bool check_(handle h) {
935935
const auto &api = detail::npy_api::get();
936936
return api.PyArray_Check_(h.ptr())
937-
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
937+
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr())
938+
&& detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
938939
}
939940

940941
protected:

tests/test_numpy_array.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,4 +385,42 @@ TEST_SUBMODULE(numpy_array, sm) {
385385
sm.def("index_using_ellipsis", [](py::array a) {
386386
return a[py::make_tuple(0, py::ellipsis(), 0)];
387387
});
388+
389+
// test_argument_conversions
390+
sm.def("accept_double",
391+
[](py::array_t<double, 0>) {},
392+
py::arg("a"));
393+
sm.def("accept_double_forcecast",
394+
[](py::array_t<double, py::array::forcecast>) {},
395+
py::arg("a"));
396+
sm.def("accept_double_c_style",
397+
[](py::array_t<double, py::array::c_style>) {},
398+
py::arg("a"));
399+
sm.def("accept_double_c_style_forcecast",
400+
[](py::array_t<double, py::array::forcecast | py::array::c_style>) {},
401+
py::arg("a"));
402+
sm.def("accept_double_f_style",
403+
[](py::array_t<double, py::array::f_style>) {},
404+
py::arg("a"));
405+
sm.def("accept_double_f_style_forcecast",
406+
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
407+
py::arg("a"));
408+
sm.def("accept_double_noconvert",
409+
[](py::array_t<double, 0>) {},
410+
py::arg("a").noconvert());
411+
sm.def("accept_double_forcecast_noconvert",
412+
[](py::array_t<double, py::array::forcecast>) {},
413+
py::arg("a").noconvert());
414+
sm.def("accept_double_c_style_noconvert",
415+
[](py::array_t<double, py::array::c_style>) {},
416+
py::arg("a").noconvert());
417+
sm.def("accept_double_c_style_forcecast_noconvert",
418+
[](py::array_t<double, py::array::forcecast | py::array::c_style>) {},
419+
py::arg("a").noconvert());
420+
sm.def("accept_double_f_style_noconvert",
421+
[](py::array_t<double, py::array::f_style>) {},
422+
py::arg("a").noconvert());
423+
sm.def("accept_double_f_style_forcecast_noconvert",
424+
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
425+
py::arg("a").noconvert());
388426
}

tests/test_numpy_array.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,52 @@ def test_index_using_ellipsis():
435435
assert a.shape == (6,)
436436

437437

438+
@pytest.mark.parametrize("forcecast", [False, True])
439+
@pytest.mark.parametrize("contiguity", [None, 'C', 'F'])
440+
@pytest.mark.parametrize("noconvert", [False, True])
441+
@pytest.mark.filterwarnings(
442+
"ignore:Casting complex values to real discards the imaginary part:numpy.ComplexWarning"
443+
)
444+
def test_argument_conversions(forcecast, contiguity, noconvert):
445+
function_name = "accept_double"
446+
if contiguity == 'C':
447+
function_name += "_c_style"
448+
elif contiguity == 'F':
449+
function_name += "_f_style"
450+
if forcecast:
451+
function_name += "_forcecast"
452+
if noconvert:
453+
function_name += "_noconvert"
454+
function = getattr(m, function_name)
455+
456+
for dtype in [np.dtype('float32'), np.dtype('float64'), np.dtype('complex128')]:
457+
for order in ['C', 'F']:
458+
for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]:
459+
if not noconvert:
460+
# If noconvert is not passed, only complex128 needs to be truncated and
461+
# "cannot be safely obtained". So without `forcecast`, the argument shouldn't
462+
# be accepted.
463+
should_raise = dtype.name == 'complex128' and not forcecast
464+
else:
465+
# If noconvert is passed, only float64 and the matching order is accepted.
466+
# If at most one dimension has a size greater than 1, the array is also
467+
# trivially contiguous.
468+
trivially_contiguous = sum(1 for d in shape if d > 1) <= 1
469+
should_raise = (
470+
dtype.name != 'float64' or
471+
(contiguity is not None and
472+
contiguity != order and
473+
not trivially_contiguous)
474+
)
475+
476+
array = np.zeros(shape, dtype=dtype, order=order)
477+
if not should_raise:
478+
function(array)
479+
else:
480+
with pytest.raises(TypeError, match="incompatible function arguments"):
481+
function(array)
482+
483+
438484
@pytest.mark.xfail("env.PYPY")
439485
def test_dtype_refcount_leak():
440486
from sys import getrefcount

0 commit comments

Comments
 (0)