Skip to content

Commit df1be99

Browse files
author
Bertrand MICHEL
committed
[dtype]: add type() method to access type attribute of PyArray_Descr (eq. to dtype.char in Python)
1 parent 417067e commit df1be99

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

include/pybind11/numpy.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,11 +507,16 @@ class dtype : public object {
507507
return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
508508
}
509509

510-
/// Single-character type code.
510+
/// Single-character for dtype's kind (ex: float and double are 'f' or int and long int are 'i')
511511
char kind() const {
512512
return detail::array_descriptor_proxy(m_ptr)->kind;
513513
}
514514

515+
/// Single-character for dtype's type (ex: float is 'f' and double 'd')
516+
char type() const {
517+
return detail::array_descriptor_proxy(m_ptr)->type;
518+
}
519+
515520
private:
516521
static object _dtype_from_pep3118() {
517522
static PyObject *obj = module_::import("numpy.core._internal")

tests/test_numpy_dtypes.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,38 @@ py::list test_dtype_ctors() {
251251
return list;
252252
}
253253

254+
py::list test_dtype_kind() {
255+
py::list list;
256+
for (auto& dt : {
257+
py::dtype("bool8"), // bool
258+
py::dtype("int16"), // short
259+
py::dtype("int32"), // int
260+
py::dtype("int64"), // long int
261+
py::dtype("float32"), // float
262+
py::dtype("float64"), // double
263+
py::dtype("float128") // long double
264+
}) {
265+
list.append(dt.kind());
266+
}
267+
return list;
268+
}
269+
270+
py::list test_dtype_type() {
271+
py::list list;
272+
for (auto& dt : {
273+
py::dtype("bool8"), // bool
274+
py::dtype("int16"), // short
275+
py::dtype("int32"), // int
276+
py::dtype("int64"), // long int
277+
py::dtype("float32"), // float
278+
py::dtype("float64"), // double
279+
py::dtype("float128") // long double
280+
}) {
281+
list.append(dt.type());
282+
}
283+
return list;
284+
}
285+
254286
struct A {};
255287
struct B {};
256288

@@ -376,6 +408,8 @@ TEST_SUBMODULE(numpy_dtypes, m) {
376408
return l;
377409
});
378410
m.def("test_dtype_ctors", &test_dtype_ctors);
411+
m.def("test_dtype_kind", &test_dtype_kind);
412+
m.def("test_dtype_type", &test_dtype_type);
379413
m.def("test_dtype_methods", []() {
380414
py::list list;
381415
auto dt1 = py::dtype::of<int32_t>();

tests/test_numpy_dtypes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ def test_dtype(simple_dtype):
169169
np.zeros(1, m.trailing_padding_dtype())
170170
)
171171

172+
# for dt in m.test_dtype_kind():
173+
# print("dt = ",dt)
174+
assert m.test_dtype_kind() == ['b'] + ['i']*3 + ['f']*3
175+
assert m.test_dtype_type() == ['?', 'h', 'i', 'l', 'f', 'd', 'g']
176+
# assert False
177+
172178

173179
def test_recarray(simple_dtype, packed_dtype):
174180
elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)]

0 commit comments

Comments
 (0)