Skip to content

Commit ac83c33

Browse files
committed
array: implement array resize
1 parent fb50ce1 commit ac83c33

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

include/pybind11/numpy.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ struct npy_api {
129129
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
130130
};
131131

132+
typedef struct {
133+
Py_intptr_t *ptr;
134+
int len;
135+
} PyArray_Dims;
136+
132137
static npy_api& get() {
133138
static npy_api api = lookup();
134139
return api;
@@ -158,6 +163,7 @@ struct npy_api {
158163
Py_ssize_t *, PyObject **, PyObject *);
159164
PyObject *(*PyArray_Squeeze_)(PyObject *);
160165
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
166+
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
161167
private:
162168
enum functions {
163169
API_PyArray_Type = 2,
@@ -166,6 +172,7 @@ struct npy_api {
166172
API_PyArray_DescrFromType = 45,
167173
API_PyArray_DescrFromScalar = 57,
168174
API_PyArray_FromAny = 69,
175+
API_PyArray_Resize = 80,
169176
API_PyArray_NewCopy = 85,
170177
API_PyArray_NewFromDescr = 94,
171178
API_PyArray_DescrNewFromType = 9,
@@ -192,6 +199,7 @@ struct npy_api {
192199
DECL_NPY_API(PyArray_DescrFromType);
193200
DECL_NPY_API(PyArray_DescrFromScalar);
194201
DECL_NPY_API(PyArray_FromAny);
202+
DECL_NPY_API(PyArray_Resize);
195203
DECL_NPY_API(PyArray_NewCopy);
196204
DECL_NPY_API(PyArray_NewFromDescr);
197205
DECL_NPY_API(PyArray_DescrNewFromType);
@@ -647,6 +655,21 @@ class array : public buffer {
647655
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
648656
}
649657

658+
/// Resize array to given shape
659+
/// If refcheck is true and more that one reference exist to this array
660+
/// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
661+
void resize(ShapeContainer new_shape, bool refcheck = true) {
662+
detail::npy_api::PyArray_Dims d = {
663+
new_shape->data(), int(new_shape->size())
664+
};
665+
// try to resize, set ordering param to -1 cause it's not used anyway
666+
object new_array = reinterpret_steal<object>(
667+
detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1)
668+
);
669+
if (!new_array) throw error_already_set();
670+
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
671+
}
672+
650673
/// Ensure that the argument is a NumPy array
651674
/// In case of an error, nullptr is returned and the Python error is cleared.
652675
static array ensure(handle h, int ExtraFlags = 0) {

tests/test_numpy_array.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,4 +267,25 @@ test_initializer numpy_array([](py::module &m) {
267267
// Issue #785: Uninformative "Unknown internal error" exception when constructing array from empty object:
268268
sm.def("array_fail_test", []() { return py::array(py::object()); });
269269
sm.def("array_t_fail_test", []() { return py::array_t<double>(py::object()); });
270+
271+
// reshape array to 2D without changing size
272+
sm.def("array_reshape2", [](py::array_t<double> a) {
273+
const size_t dim_sz = (size_t)std::sqrt(a.size());
274+
if (dim_sz * dim_sz != a.size())
275+
throw std::domain_error("array_reshape2: input array total size is not a squared integer");
276+
a.resize({dim_sz, dim_sz});
277+
});
278+
279+
// resize to 3D array with each dimension = N
280+
sm.def("array_resize3", [](py::array_t<double> a, size_t N, bool refcheck) {
281+
a.resize({N, N, N}, refcheck);
282+
});
283+
284+
// return 2D array with Nrows = Ncols = N
285+
sm.def("create_and_resize", [](size_t N) {
286+
py::array_t<double> a;
287+
a.resize({N, N});
288+
std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.);
289+
return a;
290+
});
270291
});

tests/test_numpy_array.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,34 @@ def test_array_failure():
389389
with pytest.raises(ValueError) as excinfo:
390390
array_t_fail_test()
391391
assert str(excinfo.value) == 'cannot create a pybind11::array_t from a nullptr'
392+
393+
394+
def test_array_resize(msg):
395+
from pybind11_tests.array import (array_reshape2, array_resize3, create_and_resize)
396+
397+
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float64')
398+
array_reshape2(a)
399+
assert(a.size == 9)
400+
assert(np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
401+
402+
# total size change should succced with refcheck off
403+
array_resize3(a, 4, False)
404+
assert(a.size == 64)
405+
# ... and fail with refcheck on
406+
try:
407+
array_resize3(a, 3, True)
408+
except ValueError as e:
409+
assert(str(e).startswith("cannot resize an array"))
410+
# transposed array doesn't own data
411+
b = a.transpose()
412+
try:
413+
array_resize3(b, 3, False)
414+
except ValueError as e:
415+
assert(str(e).startswith("cannot resize this array: it does not own its data"))
416+
# ... but reshape should be fine
417+
array_reshape2(b)
418+
assert(b.shape == (8, 8))
419+
420+
a = create_and_resize(2)
421+
assert(a.size == 4)
422+
assert(np.all(a == 42.))

0 commit comments

Comments
 (0)