Skip to content

Commit 914afe8

Browse files
authored
Merge pull request #12 from numpy/fix-quaddtype-debug-symbols
Fix quaddtype debug symbols
2 parents 3b0d427 + a348a91 commit 914afe8

File tree

6 files changed

+99
-28
lines changed

6 files changed

+99
-28
lines changed
File renamed without changes.

quaddtype/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ requires = [
44
"meson-python",
55
"patchelf",
66
"wheel",
7-
"numpy @ file:///home/pdmurray/Desktop/numpy-1.24.0.dev0+1120.gf30af6acd-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
7+
"numpy @ file:///home/pdmurray/Desktop/numpy-1.25.0.dev0+167.g4ca8204c5-cp311-cp311d-linux_x86_64.whl"
88
]
99
build-backend = "mesonpy"
1010

@@ -16,7 +16,7 @@ readme = 'README.md'
1616
author = "Peyton Murray"
1717
requires-python = ">=3.9.0"
1818
dependencies = [
19-
"numpy @ file:///home/pdmurray/Desktop/numpy-1.24.0.dev0+1120.gf30af6acd-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
19+
"numpy @ file:///home/pdmurray/Desktop/numpy-1.25.0.dev0+167.g4ca8204c5-cp311-cp311d-linux_x86_64.whl"
2020
]
2121

2222
[project.optional-dependencies]

quaddtype/quaddtype/src/casts.c

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,69 @@ quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
2929
return NPY_SAME_KIND_CASTING;
3030
}
3131

32+
// Each element is a __float128 element; no casting needed
3233
static int
3334
quad_to_quad_contiguous(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
34-
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
35+
npy_intp const dimensions[], npy_intp const strides[],
36+
NpyAuxData *NPY_UNUSED(auxdata))
3537
{
38+
npy_intp N = dimensions[0];
39+
__float128 *in = (__float128 *)data[0];
40+
__float128 *out = (__float128 *)data[1];
41+
42+
while (N--) {
43+
*out = *in;
44+
out++;
45+
in++;
46+
}
47+
3648
return 0;
3749
}
3850

51+
// Elements are strided, e.g.
52+
//
53+
// x = np.linspace(40)
54+
// x[::3]
55+
//
56+
// Therefore the stride needs to be used to increment the pointers inside the loop.
3957
static int
4058
quad_to_quad_strided(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
41-
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
59+
npy_intp const dimensions[], npy_intp const strides[],
60+
NpyAuxData *NPY_UNUSED(auxdata))
4261
{
62+
npy_intp N = dimensions[0];
63+
char *in = data[0];
64+
char *out = data[1];
65+
npy_intp in_stride = strides[0];
66+
npy_intp out_stride = strides[1];
67+
68+
while (N--) {
69+
*(__float128 *)out = *(__float128 *)in;
70+
in += in_stride;
71+
out += out_stride;
72+
}
73+
4374
return 0;
4475
}
4576

77+
// Arrays are unaligned.
4678
static int
4779
quad_to_quad_unaligned(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
48-
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
80+
npy_intp const dimensions[], npy_intp const strides[],
81+
NpyAuxData *NPY_UNUSED(auxdata))
4982
{
83+
npy_intp N = dimensions[0];
84+
char *in = data[0];
85+
char *out = data[1];
86+
npy_intp in_stride = strides[0];
87+
npy_intp out_stride = strides[1];
88+
89+
while (N--) {
90+
memcpy(out, in, sizeof(__float128)); // NOLINT
91+
in += in_stride;
92+
out += out_stride;
93+
}
94+
5095
return 0;
5196
}
5297

quaddtype/quaddtype/src/dtype.c

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,5 @@
1-
#include <Python.h>
2-
3-
#define PY_ARRAY_UNIQUE_SYMBOL quaddtype_ARRAY_API
4-
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
5-
#define NO_IMPORT_ARRAY
6-
#include "numpy/arrayobject.h"
7-
#include "numpy/experimental_dtype_api.h"
8-
#include "numpy/ndarraytypes.h"
9-
10-
#include "casts.h"
111
#include "dtype.h"
2+
#include "casts.h"
123

134
PyTypeObject *QuadScalar_Type = NULL;
145

@@ -17,8 +8,9 @@ new_quaddtype_instance(void)
178
{
189
QuadDTypeObject *new =
1910
(QuadDTypeObject *)PyArrayDescr_Type.tp_new((PyTypeObject *)&QuadDType, NULL, NULL);
20-
if (new == NULL)
11+
if (new == NULL) {
2112
return NULL;
13+
}
2214

2315
new->base.elsize = sizeof(__float128);
2416
new->base.alignment = _Alignof(__float128);
@@ -82,14 +74,34 @@ common_dtype(PyArray_DTypeMeta *self, PyArray_DTypeMeta *other)
8274
return (PyArray_DTypeMeta *)Py_NotImplemented;
8375
}
8476

85-
static PyType_Slot QuadDType_Slots[] = {{NPY_DT_common_instance, &common_instance},
86-
{NPY_DT_common_dtype, &common_dtype},
87-
// {NPY_DT_discover_descr_from_pyobject,
88-
// &unit_discover_descriptor_from_pyobject},
89-
/* The header is wrong on main :(, so we add 1 */
90-
{NPY_DT_setitem, &quad_setitem},
91-
{NPY_DT_getitem, &quad_getitem},
92-
{0, NULL}};
77+
static QuadDTypeObject *
78+
quaddtype_ensure_canonical(QuadDTypeObject *self)
79+
{
80+
Py_INCREF(self);
81+
return self;
82+
}
83+
84+
static PyArray_Descr *
85+
quad_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls), PyObject *obj)
86+
{
87+
if (Py_TYPE(obj) != QuadScalar_Type) {
88+
PyErr_SetString(PyExc_TypeError, "Can only store QuadScalars in a QuadDType array.");
89+
return NULL;
90+
}
91+
92+
// Get the dtype attribute from the object.
93+
return (PyArray_Descr *)PyObject_GetAttrString(obj, "dtype");
94+
}
95+
96+
static PyType_Slot QuadDType_Slots[] = {
97+
{NPY_DT_common_instance, &common_instance},
98+
{NPY_DT_common_dtype, &common_dtype},
99+
{NPY_DT_discover_descr_from_pyobject, &quad_discover_descriptor_from_pyobject},
100+
/* The header is wrong on main :(, so we add 1 */
101+
{NPY_DT_setitem, &quad_setitem},
102+
{NPY_DT_getitem, &quad_getitem},
103+
{NPY_DT_ensure_canonical, &quaddtype_ensure_canonical},
104+
{0, NULL}};
93105

94106
/*
95107
* The following defines everything type object related (i.e. not NumPy
@@ -114,7 +126,8 @@ quaddtype_dealloc(QuadDTypeObject *self)
114126
static PyObject *
115127
quaddtype_repr(QuadDTypeObject *self)
116128
{
117-
return PyUnicode_FromString("This is a quad (128-bit float) dtype.");
129+
PyObject *res = PyUnicode_FromString("This is a quad (128-bit float) dtype.");
130+
return res;
118131
}
119132

120133
// These are the basic things that you need to create a Python Type/Class in C.
@@ -140,7 +153,6 @@ init_quad_dtype(void)
140153
// do it. You first have to create a static type, but see the note there!
141154
PyArrayMethod_Spec *casts[] = {
142155
&QuadToQuadCastSpec,
143-
&QuadToFloat128CastSpec,
144156
NULL,
145157
};
146158

quaddtype/quaddtype/src/dtype.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
#include <Python.h>
55

6+
#define PY_ARRAY_UNIQUE_SYMBOL quaddtype_ARRAY_API
7+
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
8+
#define NO_IMPORT_ARRAY
9+
#include "numpy/ndarraytypes.h"
610
#include "numpy/arrayobject.h"
711
#include "numpy/experimental_dtype_api.h"
812

@@ -15,6 +19,7 @@ extern PyTypeObject *QuadScalar_Type;
1519

1620
QuadDTypeObject *
1721
new_quaddtype_instance(void);
22+
1823
int
1924
init_quad_dtype(void);
2025

quaddtype/quaddtype/src/umath.c

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ quad_multiply_strided_loop(PyArrayMethod_Context *context, char *const data[],
2525
char *in1 = data[0], *in2 = data[1];
2626
char *out = data[2];
2727
npy_intp in1_stride = strides[0];
28-
npy_intp in2_stride = strides[0];
28+
npy_intp in2_stride = strides[1];
2929
npy_intp out_stride = strides[2];
3030

3131
while (N--) {
@@ -56,9 +56,18 @@ quad_multiply_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
5656
// The operand units can be used as-is; no casting required for quad types.
5757
Py_INCREF(given_descrs[0]);
5858
loop_descrs[0] = given_descrs[0];
59+
60+
if (given_descrs[1] == NULL) {
61+
Py_INCREF(given_descrs[0]);
62+
loop_descrs[1] = given_descrs[0];
63+
return NPY_NO_CASTING;
64+
}
65+
Py_INCREF(given_descrs[1]);
66+
loop_descrs[1] = given_descrs[1];
67+
5968
Py_INCREF(given_descrs[1]);
6069
loop_descrs[1] = given_descrs[1];
61-
return NPY_NO_CASTING;
70+
return NPY_SAFE_CASTING;
6271
}
6372

6473
// Function that adds our multiply loop to NumPy's multiply ufunc.

0 commit comments

Comments
 (0)