Skip to content

Commit 92353e5

Browse files
authored
Merge pull request #13 from ngoldbaum/add-asciidtype
Add round-trip casts between unicode and ASCIIDType
2 parents 914afe8 + dae600b commit 92353e5

File tree

6 files changed

+353
-45
lines changed

6 files changed

+353
-45
lines changed

asciidtype/.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[flake8]
22
per-file-ignores = __init__.py:F401
3+
max-line-length = 160

asciidtype/asciidtype/src/casts.c

Lines changed: 210 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,17 @@ ascii_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self),
2828
loop_descrs[1] = given_descrs[1];
2929
}
3030

31-
if (((ASCIIDTypeObject *)loop_descrs[0])->size ==
32-
((ASCIIDTypeObject *)loop_descrs[1])->size) {
31+
long in_size = ((ASCIIDTypeObject *)loop_descrs[0])->size;
32+
long out_size = ((ASCIIDTypeObject *)loop_descrs[1])->size;
33+
34+
if (in_size == out_size) {
3335
*view_offset = 0;
3436
return NPY_NO_CASTING;
3537
}
36-
37-
return NPY_SAME_KIND_CASTING;
38+
else if (in_size > out_size) {
39+
return NPY_UNSAFE_CASTING;
40+
}
41+
return NPY_SAFE_CASTING;
3842
}
3943

4044
static int
@@ -72,33 +76,224 @@ ascii_to_ascii(PyArrayMethod_Context *context, char *const data[],
7276
return 0;
7377
}
7478

79+
static NPY_CASTING
80+
unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self),
81+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
82+
PyArray_Descr *given_descrs[2],
83+
PyArray_Descr *loop_descrs[2],
84+
npy_intp *NPY_UNUSED(view_offset))
85+
{
86+
Py_INCREF(given_descrs[0]);
87+
loop_descrs[0] = given_descrs[0];
88+
// numpy stores unicode as UCS4 (4 bytes wide), so bitshift
89+
// by 2 to get the number of ASCII bytes needed
90+
long in_size = (loop_descrs[0]->elsize) >> 2;
91+
if (given_descrs[1] == NULL) {
92+
ASCIIDTypeObject *ascii_descr = new_asciidtype_instance(in_size);
93+
loop_descrs[1] = (PyArray_Descr *)ascii_descr;
94+
}
95+
else {
96+
Py_INCREF(given_descrs[1]);
97+
loop_descrs[1] = given_descrs[1];
98+
}
99+
100+
long out_size = ((ASCIIDTypeObject *)loop_descrs[1])->size;
101+
102+
if (out_size >= in_size) {
103+
return NPY_SAFE_CASTING;
104+
}
105+
106+
return NPY_UNSAFE_CASTING;
107+
}
108+
75109
static int
76-
ascii_to_ascii_get_loop(PyArrayMethod_Context *context, int aligned,
77-
int NPY_UNUSED(move_references),
78-
const npy_intp *strides,
79-
PyArrayMethod_StridedLoop **out_loop,
80-
NpyAuxData **NPY_UNUSED(out_transferdata),
81-
NPY_ARRAYMETHOD_FLAGS *flags)
110+
unicode_to_ascii(PyArrayMethod_Context *context, char *const data[],
111+
npy_intp const dimensions[], npy_intp const strides[],
112+
NpyAuxData *NPY_UNUSED(auxdata))
82113
{
83-
*out_loop = (PyArrayMethod_StridedLoop *)&ascii_to_ascii;
114+
PyArray_Descr **descrs = context->descriptors;
115+
long in_size = (descrs[0]->elsize) / 4;
116+
long out_size = ((ASCIIDTypeObject *)descrs[1])->size;
117+
long copy_size;
118+
119+
if (out_size > in_size) {
120+
copy_size = in_size;
121+
}
122+
else {
123+
copy_size = out_size;
124+
}
125+
126+
npy_intp N = dimensions[0];
127+
char *in = data[0];
128+
char *out = data[1];
129+
npy_intp in_stride = strides[0];
130+
npy_intp out_stride = strides[1];
131+
132+
while (N--) {
133+
// copy input characters, checking that input UCS4
134+
// characters are all ascii, raising an error otherwise
135+
for (int i = 0; i < copy_size; i++) {
136+
Py_UCS4 c = ((Py_UCS4 *)in)[i];
137+
if (c > 127) {
138+
PyErr_SetString(
139+
PyExc_TypeError,
140+
"Can only store ASCII text in a ASCIIDType array.");
141+
return -1;
142+
}
143+
// UCS4 character is ascii, so casting to Py_UCS1 does not truncate
144+
out[i] = (Py_UCS1)c;
145+
}
146+
// write zeros to remaining ASCII characters (if any)
147+
for (int i = copy_size; i < out_size; i++) {
148+
*(out + i) = '\0';
149+
}
150+
in += in_stride;
151+
out += out_stride;
152+
}
84153

85-
*flags = 0;
86154
return 0;
87155
}
88156

157+
static int
158+
ascii_to_unicode(PyArrayMethod_Context *context, char *const data[],
159+
npy_intp const dimensions[], npy_intp const strides[],
160+
NpyAuxData *NPY_UNUSED(auxdata))
161+
{
162+
PyArray_Descr **descrs = context->descriptors;
163+
long in_size = ((ASCIIDTypeObject *)descrs[0])->size;
164+
long out_size = (descrs[1]->elsize) / 4;
165+
long copy_size;
166+
167+
if (out_size > in_size) {
168+
copy_size = in_size;
169+
}
170+
else {
171+
copy_size = out_size;
172+
}
173+
174+
npy_intp N = dimensions[0];
175+
char *in = data[0];
176+
char *out = data[1];
177+
npy_intp in_stride = strides[0];
178+
npy_intp out_stride = strides[1];
179+
180+
while (N--) {
181+
// copy ASCII input to first byte, fill rest with zeros
182+
for (int i = 0; i < copy_size; i++) {
183+
((Py_UCS4 *)out)[i] = ((Py_UCS1 *)in)[i];
184+
}
185+
// fill all remaining UCS4 characters with zeros
186+
for (int i = copy_size; i < out_size; i++) {
187+
((Py_UCS4 *)out)[i] = (Py_UCS1)0;
188+
}
189+
in += in_stride;
190+
out += out_stride;
191+
}
192+
return 0;
193+
}
194+
195+
static NPY_CASTING
196+
ascii_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
197+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
198+
PyArray_Descr *given_descrs[2],
199+
PyArray_Descr *loop_descrs[2],
200+
npy_intp *NPY_UNUSED(view_offset))
201+
{
202+
Py_INCREF(given_descrs[0]);
203+
loop_descrs[0] = given_descrs[0];
204+
long in_size = ((ASCIIDTypeObject *)given_descrs[0])->size;
205+
if (given_descrs[1] == NULL) {
206+
PyArray_Descr *unicode_descr = PyArray_DescrNewFromType(NPY_UNICODE);
207+
// numpy stores unicode as UCS4 (4 bytes wide), so bitshift
208+
// by 2 to get the number of bytes needed to store the UCS4 charaters
209+
unicode_descr->elsize = in_size << 2;
210+
loop_descrs[1] = unicode_descr;
211+
}
212+
else {
213+
Py_INCREF(given_descrs[1]);
214+
loop_descrs[1] = given_descrs[1];
215+
}
216+
217+
long out_size = (loop_descrs[1]->elsize) >> 2;
218+
219+
if (out_size >= in_size) {
220+
return NPY_SAFE_CASTING;
221+
}
222+
223+
return NPY_UNSAFE_CASTING;
224+
}
225+
89226
static PyArray_DTypeMeta *a2a_dtypes[2] = {NULL, NULL};
90227

91228
static PyType_Slot a2a_slots[] = {
92229
{NPY_METH_resolve_descriptors, &ascii_to_ascii_resolve_descriptors},
93-
{_NPY_METH_get_loop, &ascii_to_ascii_get_loop},
230+
{NPY_METH_strided_loop, &ascii_to_ascii},
231+
{NPY_METH_unaligned_strided_loop, &ascii_to_ascii},
94232
{0, NULL}};
95233

96234
PyArrayMethod_Spec ASCIIToASCIICastSpec = {
97235
.name = "cast_ASCIIDType_to_ASCIIDType",
98236
.nin = 1,
99237
.nout = 1,
100-
.flags = NPY_METH_SUPPORTS_UNALIGNED,
101-
.casting = NPY_SAME_KIND_CASTING,
238+
.casting = NPY_UNSAFE_CASTING,
239+
.flags = (NPY_METH_NO_FLOATINGPOINT_ERRORS |
240+
NPY_METH_SUPPORTS_UNALIGNED),
102241
.dtypes = a2a_dtypes,
103242
.slots = a2a_slots,
104243
};
244+
245+
static PyType_Slot u2a_slots[] = {
246+
{NPY_METH_resolve_descriptors, &unicode_to_ascii_resolve_descriptors},
247+
{NPY_METH_strided_loop, &unicode_to_ascii},
248+
{0, NULL}};
249+
250+
static char *u2a_name = "cast_Unicode_to_ASCIIDType";
251+
252+
static PyType_Slot a2u_slots[] = {
253+
{NPY_METH_resolve_descriptors, &ascii_to_unicode_resolve_descriptors},
254+
{NPY_METH_strided_loop, &ascii_to_unicode},
255+
{0, NULL}};
256+
257+
static char *a2u_name = "cast_ASCIIDType_to_Unicode";
258+
259+
PyArrayMethod_Spec **
260+
get_casts(void)
261+
{
262+
PyArray_DTypeMeta **u2a_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
263+
u2a_dtypes[0] = &PyArray_UnicodeDType;
264+
u2a_dtypes[1] = NULL;
265+
266+
PyArrayMethod_Spec *UnicodeToASCIICastSpec =
267+
malloc(sizeof(PyArrayMethod_Spec));
268+
269+
UnicodeToASCIICastSpec->name = u2a_name;
270+
UnicodeToASCIICastSpec->nin = 1;
271+
UnicodeToASCIICastSpec->nout = 1;
272+
UnicodeToASCIICastSpec->casting = NPY_UNSAFE_CASTING;
273+
UnicodeToASCIICastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
274+
UnicodeToASCIICastSpec->dtypes = u2a_dtypes;
275+
UnicodeToASCIICastSpec->slots = u2a_slots;
276+
277+
PyArray_DTypeMeta **a2u_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
278+
a2u_dtypes[0] = NULL;
279+
a2u_dtypes[1] = &PyArray_UnicodeDType;
280+
281+
PyArrayMethod_Spec *ASCIIToUnicodeCastSpec =
282+
malloc(sizeof(PyArrayMethod_Spec));
283+
284+
ASCIIToUnicodeCastSpec->name = a2u_name;
285+
ASCIIToUnicodeCastSpec->nin = 1;
286+
ASCIIToUnicodeCastSpec->nout = 1;
287+
ASCIIToUnicodeCastSpec->casting = NPY_UNSAFE_CASTING;
288+
ASCIIToUnicodeCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
289+
ASCIIToUnicodeCastSpec->dtypes = a2u_dtypes;
290+
ASCIIToUnicodeCastSpec->slots = a2u_slots;
291+
292+
PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *));
293+
casts[0] = &ASCIIToASCIICastSpec;
294+
casts[1] = UnicodeToASCIICastSpec;
295+
casts[2] = ASCIIToUnicodeCastSpec;
296+
casts[3] = NULL;
297+
298+
return casts;
299+
}

asciidtype/asciidtype/src/casts.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "numpy/experimental_dtype_api.h"
1111
#include "numpy/ndarraytypes.h"
1212

13-
extern PyArrayMethod_Spec ASCIIToASCIICastSpec;
13+
PyArrayMethod_Spec **
14+
get_casts(void);
1415

1516
#endif /* _NPY_CASTS_H */

asciidtype/asciidtype/src/dtype.c

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ get_value(PyObject *scalar)
1616
PyErr_SetString(
1717
PyExc_TypeError,
1818
"Can only store ASCII text in a ASCIIDType array.");
19+
return NULL;
1920
}
2021
}
2122
else if (scalar_type != ASCIIScalar_Type) {
@@ -29,6 +30,12 @@ get_value(PyObject *scalar)
2930
return NULL;
3031
}
3132
ret_bytes = PyUnicode_AsASCIIString(value);
33+
if (ret_bytes == NULL) {
34+
PyErr_SetString(
35+
PyExc_TypeError,
36+
"Can only store ASCII text in a ASCIIDType array.");
37+
return NULL;
38+
}
3239
Py_DECREF(value);
3340
}
3441
return ret_bytes;
@@ -38,20 +45,16 @@ get_value(PyObject *scalar)
3845
* Internal helper to create new instances
3946
*/
4047
ASCIIDTypeObject *
41-
new_asciidtype_instance(PyObject *size)
48+
new_asciidtype_instance(long size)
4249
{
4350
ASCIIDTypeObject *new = (ASCIIDTypeObject *)PyArrayDescr_Type.tp_new(
4451
(PyTypeObject *)&ASCIIDType, NULL, NULL);
4552
if (new == NULL) {
4653
return NULL;
4754
}
48-
long size_l = PyLong_AsLong(size);
49-
if (size_l == -1 && PyErr_Occurred()) {
50-
return NULL;
51-
}
52-
new->size = size_l;
53-
new->base.elsize = size_l * sizeof(char);
54-
new->base.alignment = size_l *_Alignof(char);
55+
new->size = size;
56+
new->base.elsize = size * sizeof(char);
57+
new->base.alignment = size *_Alignof(char);
5558

5659
return new;
5760
}
@@ -182,18 +185,14 @@ asciidtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
182185
{
183186
static char *kwargs_strs[] = {"size", NULL};
184187

185-
PyObject *size = NULL;
188+
long size = 0;
186189

187-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O:ASCIIDType", kwargs_strs,
190+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|l:ASCIIDType", kwargs_strs,
188191
&size)) {
189192
return NULL;
190193
}
191-
if (size == NULL) {
192-
size = PyLong_FromLong(0);
193-
}
194194

195195
PyObject *ret = (PyObject *)new_asciidtype_instance(size);
196-
Py_DECREF(size);
197196
return ret;
198197
}
199198

@@ -239,7 +238,7 @@ PyArray_DTypeMeta ASCIIDType = {
239238
int
240239
init_ascii_dtype(void)
241240
{
242-
static PyArrayMethod_Spec *casts[] = {&ASCIIToASCIICastSpec, NULL};
241+
PyArrayMethod_Spec **casts = get_casts();
243242

244243
PyArrayDTypeMeta_Spec ASCIIDType_DTypeSpec = {
245244
.flags = NPY_DT_PARAMETRIC,
@@ -267,5 +266,11 @@ init_ascii_dtype(void)
267266

268267
ASCIIDType.singleton = singleton;
269268

269+
free(ASCIIDType_DTypeSpec.casts[1]->dtypes);
270+
free(ASCIIDType_DTypeSpec.casts[1]);
271+
free(ASCIIDType_DTypeSpec.casts[2]->dtypes);
272+
free(ASCIIDType_DTypeSpec.casts[2]);
273+
free(ASCIIDType_DTypeSpec.casts);
274+
270275
return 0;
271276
}

asciidtype/asciidtype/src/dtype.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ extern PyArray_DTypeMeta ASCIIDType;
2222
extern PyTypeObject *ASCIIScalar_Type;
2323

2424
ASCIIDTypeObject *
25-
new_asciidtype_instance(PyObject *size);
25+
new_asciidtype_instance(long size);
2626

2727
int
2828
init_ascii_dtype(void);

0 commit comments

Comments
 (0)