@@ -305,14 +305,16 @@ class dtype : public object {
305305
306306class array : public buffer {
307307public:
308- PYBIND11_OBJECT_DEFAULT (array, buffer, detail::npy_api::get().PyArray_Check_)
308+ PYBIND11_OBJECT_CVT (array, buffer, detail::npy_api::get().PyArray_Check_, raw_array )
309309
310310 enum {
311311 c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
312312 f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
313313 forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
314314 };
315315
316+ array () : array(0 , static_cast <const double *>(nullptr )) {}
317+
316318 array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
317319 const std::vector<size_t > &strides, const void *ptr = nullptr ,
318320 handle base = handle()) {
@@ -478,10 +480,12 @@ class array : public buffer {
478480 }
479481
480482 // / Ensure that the argument is a NumPy array
481- static array ensure (object input, int ExtraFlags = 0 ) {
482- auto & api = detail::npy_api::get ();
483- return reinterpret_steal<array>(api.PyArray_FromAny_ (
484- input.release ().ptr (), nullptr , 0 , 0 , detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr ));
483+ // / In case of an error, nullptr is returned and the Python error is cleared.
484+ static array ensure (handle h, int ExtraFlags = 0 ) {
485+ auto result = reinterpret_steal<array>(raw_array (h.ptr (), ExtraFlags));
486+ if (!result)
487+ PyErr_Clear ();
488+ return result;
485489 }
486490
487491protected:
@@ -520,8 +524,6 @@ class array : public buffer {
520524 return strides;
521525 }
522526
523- protected:
524-
525527 template <typename ... Ix> void check_dimensions (Ix... index) const {
526528 check_dimensions_impl (size_t (0 ), shape (), size_t (index)...);
527529 }
@@ -536,15 +538,31 @@ class array : public buffer {
536538 }
537539 check_dimensions_impl (axis + 1 , shape + 1 , index...);
538540 }
541+
542+ // / Create array from any object -- always returns a new reference
543+ static PyObject *raw_array (PyObject *ptr, int ExtraFlags = 0 ) {
544+ if (ptr == nullptr )
545+ return nullptr ;
546+ return detail::npy_api::get ().PyArray_FromAny_ (
547+ ptr, nullptr , 0 , 0 , detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
548+ }
539549};
540550
541551template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
542552public:
543- array_t () : array() { }
553+ array_t () : array(0 , static_cast <const T *>(nullptr )) {}
554+ array_t (handle h, borrowed_t ) : array(h, borrowed) { }
555+ array_t (handle h, stolen_t ) : array(h, stolen) { }
544556
545- array_t (handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_ (m_ptr); }
557+ PYBIND11_DEPRECATED (" Use array_t<T>::ensure() instead" )
558+ array_t (handle h, bool is_borrowed) : array(raw_array_t (h.ptr()), stolen) {
559+ if (!m_ptr) PyErr_Clear ();
560+ if (!is_borrowed) Py_XDECREF (h.ptr ());
561+ }
546562
547- array_t (const object &o) : array(o) { m_ptr = ensure_ (m_ptr); }
563+ array_t (const object &o) : array(raw_array_t (o.ptr()), stolen) {
564+ if (!m_ptr) throw error_already_set ();
565+ }
548566
549567 explicit array_t (const buffer_info& info) : array(info) { }
550568
@@ -590,17 +608,30 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
590608 return *(static_cast <T*>(array::mutable_data ()) + byte_offset (size_t (index)...) / itemsize ());
591609 }
592610
593- static PyObject *ensure_ (PyObject *ptr) {
594- if (ptr == nullptr )
595- return nullptr ;
596- auto & api = detail::npy_api::get ();
597- PyObject *result = api.PyArray_FromAny_ (ptr, pybind11::dtype::of<T>().release ().ptr (), 0 , 0 ,
598- detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
611+ // / Ensure that the argument is a NumPy array of the correct dtype.
612+ // / In case of an error, nullptr is returned and the Python error is cleared.
613+ static array_t ensure (handle h) {
614+ auto result = reinterpret_steal<array_t >(raw_array_t (h.ptr ()));
599615 if (!result)
600616 PyErr_Clear ();
601- Py_DECREF (ptr);
602617 return result;
603618 }
619+
620+ static bool _check (handle h) {
621+ const auto &api = detail::npy_api::get ();
622+ return api.PyArray_Check_ (h.ptr ())
623+ && api.PyArray_EquivTypes_ (PyArray_GET_ (h.ptr (), descr), dtype::of<T>().ptr ());
624+ }
625+
626+ protected:
627+ // / Create array from any object -- always returns a new reference
628+ static PyObject *raw_array_t (PyObject *ptr) {
629+ if (ptr == nullptr )
630+ return nullptr ;
631+ return detail::npy_api::get ().PyArray_FromAny_ (
632+ ptr, dtype::of<T>().release ().ptr (), 0 , 0 ,
633+ detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
634+ }
604635};
605636
606637template <typename T>
@@ -631,7 +662,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
631662 using type = array_t <T, ExtraFlags>;
632663
633664 bool load (handle src, bool /* convert */ ) {
634- value = type (src, true );
665+ value = type::ensure (src);
635666 return static_cast <bool >(value);
636667 }
637668
0 commit comments