@@ -455,12 +455,18 @@ class array : public buffer {
455455
456456 array () : array(0 , static_cast <const double *>(nullptr )) {}
457457
458- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
459- const std::vector<size_t > &strides, const void *ptr = nullptr ,
460- handle base = handle()) {
461- auto & api = detail::npy_api::get ();
462- auto ndim = shape.size ();
463- if (shape.size () != strides.size ())
458+ using ShapeContainer = detail::any_container<Py_intptr_t>;
459+ using StridesContainer = detail::any_container<Py_intptr_t>;
460+
461+ // Constructs an array taking shape/strides from arbitrary container types
462+ array (const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
463+ const void *ptr = nullptr , handle base = handle()) {
464+
465+ if (strides->empty ())
466+ strides = default_strides (*shape, dt.itemsize ());
467+
468+ auto ndim = shape->size ();
469+ if (ndim != strides->size ())
464470 pybind11_fail (" NumPy: shape ndim doesn't match strides ndim" );
465471 auto descr = dt;
466472
@@ -474,10 +480,9 @@ class array : public buffer {
474480 flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
475481 }
476482
483+ auto &api = detail::npy_api::get ();
477484 auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_ (
478- api.PyArray_Type_ , descr.release ().ptr (), (int ) ndim,
479- reinterpret_cast <Py_intptr_t *>(const_cast <size_t *>(shape.data ())),
480- reinterpret_cast <Py_intptr_t *>(const_cast <size_t *>(strides.data ())),
485+ api.PyArray_Type_ , descr.release ().ptr (), (int ) ndim, shape->data (), strides->data (),
481486 const_cast <void *>(ptr), flags, nullptr ));
482487 if (!tmp)
483488 pybind11_fail (" NumPy: unable to create array!" );
@@ -491,27 +496,37 @@ class array : public buffer {
491496 m_ptr = tmp.release ().ptr ();
492497 }
493498
494- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
495- const void *ptr = nullptr , handle base = handle())
496- : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
499+ template <typename ShapeIt, typename StridesIt,
500+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
501+ array (const pybind11::dtype &dt, ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
502+ const void *ptr = nullptr , handle base = handle())
503+ : array(dt, {shape_first, shape_last}, {strides_first, strides_last}, ptr, base) { }
504+
505+ array (const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr , handle base = handle())
506+ : array(dt, std::move(shape), {}, ptr, base) { }
497507
498508 array (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
499509 handle base = handle())
500- : array(dt, std::vector< size_t >{ count }, ptr, base) { }
510+ : array(dt, ShapeContainer{{ count } }, ptr, base) { }
501511
502- template <typename T> array (const std::vector<size_t >& shape,
503- const std::vector<size_t >& strides,
504- const T* ptr, handle base = handle())
505- : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
512+ template <typename T, typename ShapeIt, typename StridesIt,
513+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
514+ array (ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
515+ const T *ptr = nullptr , handle base = handle())
516+ : array(pybind11::dtype::of<T>(), ShapeContainer(std::move(shape_first), std::move(shape_last)),
517+ StrideContainer (std::move(strides_first), std::move(strides_last)), ptr, base) { }
506518
507519 template <typename T>
508- array (const std::vector<size_t > &shape, const T *ptr,
509- handle base = handle())
510- : array(shape, default_strides(shape, sizeof (T)), ptr, base) { }
520+ array (ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
521+ : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
522+
523+ template <typename T>
524+ array (ShapeContainer shape, const T *ptr, handle base = handle())
525+ : array(std::move(shape), {}, ptr, base) { }
511526
512527 template <typename T>
513528 array (size_t count, const T *ptr, handle base = handle())
514- : array(std::vector< size_t >{ count }, ptr, base) { }
529+ : array({{ count } }, ptr, base) { }
515530
516531 explicit array (const buffer_info &info)
517532 : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@@ -673,9 +688,9 @@ class array : public buffer {
673688 throw std::domain_error (" array is not writeable" );
674689 }
675690
676- static std::vector<size_t > default_strides (const std::vector<size_t >& shape, size_t itemsize) {
691+ static std::vector<Py_intptr_t > default_strides (const std::vector<Py_intptr_t >& shape, size_t itemsize) {
677692 auto ndim = shape.size ();
678- std::vector<size_t > strides (ndim);
693+ std::vector<Py_intptr_t > strides (ndim);
679694 if (ndim) {
680695 std::fill (strides.begin (), strides.end (), itemsize);
681696 for (size_t i = 0 ; i < ndim - 1 ; i++)
@@ -729,14 +744,18 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
729744
730745 explicit array_t (const buffer_info& info) : array(info) { }
731746
732- array_t (const std::vector<size_t > &shape,
733- const std::vector<size_t > &strides, const T *ptr = nullptr ,
734- handle base = handle())
735- : array(shape, strides, ptr, base) { }
747+ array_t (ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr , handle base = handle())
748+ : array(std::move(shape), std::move(strides), ptr, base) { }
749+
750+ template <typename ShapeIt, typename StridesIt,
751+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
752+ array_t (ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
753+ const T *ptr = nullptr , handle base = handle())
754+ : array(ShapeContainer(std::move(shape_first), std::move(shape_last)),
755+ StridesContainer (std::move(strides_first), std::move(strides_last)), ptr, base) { }
736756
737- explicit array_t (const std::vector<size_t > &shape, const T *ptr = nullptr ,
738- handle base = handle())
739- : array(shape, ptr, base) { }
757+ explicit array_t (ShapeContainer shape, const T *ptr = nullptr , handle base = handle())
758+ : array(std::move(shape), ptr, base) { }
740759
741760 explicit array_t (size_t count, const T *ptr = nullptr , handle base = handle())
742761 : array(count, ptr, base) { }
0 commit comments