Skip to content

Commit bc38f89

Browse files
committed
add ascii to unicode and unicode to ascii casts
1 parent 9e86760 commit bc38f89

File tree

4 files changed

+308
-23
lines changed

4 files changed

+308
-23
lines changed

asciidtype/asciidtype/src/casts.c

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,190 @@ ascii_to_ascii_get_loop(PyArrayMethod_Context *NPY_UNUSED(context),
8787
return 0;
8888
}
8989

90+
static NPY_CASTING
91+
unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self),
92+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
93+
PyArray_Descr *given_descrs[2],
94+
PyArray_Descr *loop_descrs[2],
95+
npy_intp *NPY_UNUSED(view_offset))
96+
{
97+
Py_INCREF(given_descrs[0]);
98+
loop_descrs[0] = given_descrs[0];
99+
if (given_descrs[1] == NULL) {
100+
Py_INCREF(given_descrs[0]);
101+
loop_descrs[1] = given_descrs[0];
102+
}
103+
else {
104+
Py_INCREF(given_descrs[1]);
105+
loop_descrs[1] = given_descrs[1];
106+
}
107+
108+
return NPY_SAME_KIND_CASTING;
109+
}
110+
111+
static int
112+
ucs4_character_is_ascii(char *buffer)
113+
{
114+
int first_char = buffer[0];
115+
116+
if (first_char < 0 || first_char > 127) {
117+
return -1;
118+
}
119+
120+
for (int i = 1; i < 4; i++) {
121+
if (buffer[i] != 0) {
122+
return -1;
123+
}
124+
}
125+
126+
return 0;
127+
}
128+
129+
static int
130+
unicode_to_ascii(PyArrayMethod_Context *context, char *const data[],
131+
npy_intp const dimensions[], npy_intp const strides[],
132+
NpyAuxData *NPY_UNUSED(auxdata))
133+
{
134+
PyArray_Descr **descrs = context->descriptors;
135+
long in_size = (descrs[0]->elsize) / 4;
136+
long out_size = ((ASCIIDTypeObject *)descrs[1])->size;
137+
long copy_size;
138+
139+
if (out_size > in_size) {
140+
copy_size = in_size;
141+
}
142+
else {
143+
copy_size = out_size;
144+
}
145+
146+
npy_intp N = dimensions[0];
147+
char *in = data[0];
148+
char *out = data[1];
149+
npy_intp in_stride = strides[0];
150+
npy_intp out_stride = strides[1];
151+
152+
while (N--) {
153+
// copy input characters, checking that input UCS4
154+
// characters are all ascii, raising an error otherwise
155+
for (int i = 0; i < copy_size; i++) {
156+
if (ucs4_character_is_ascii(in) == -1) {
157+
PyGILState_STATE gstate;
158+
gstate = PyGILState_Ensure();
159+
PyErr_SetString(
160+
PyExc_TypeError,
161+
"Can only store ASCII text in a ASCIIDType array.");
162+
PyGILState_Release(gstate);
163+
return -1;
164+
}
165+
// UCS4 character is ascii, so copy first byte of character
166+
// into output, ignoring the rest
167+
*(out + i) = *(in + i * 4);
168+
}
169+
// write zeros to remaining ASCII characters (if any)
170+
for (int i = copy_size; i < out_size; i++) {
171+
*(out + i) = '\0';
172+
}
173+
in += in_stride;
174+
out += out_stride;
175+
}
176+
177+
return 0;
178+
}
179+
180+
static int
181+
unicode_to_ascii_get_loop(PyArrayMethod_Context *NPY_UNUSED(context),
182+
int NPY_UNUSED(aligned),
183+
int NPY_UNUSED(move_references),
184+
const npy_intp *NPY_UNUSED(strides),
185+
PyArrayMethod_StridedLoop **out_loop,
186+
NpyAuxData **NPY_UNUSED(out_transferdata),
187+
NPY_ARRAYMETHOD_FLAGS *flags)
188+
{
189+
*out_loop = (PyArrayMethod_StridedLoop *)&unicode_to_ascii;
190+
191+
*flags = 0;
192+
return 0;
193+
}
194+
195+
static int
196+
ascii_to_unicode(PyArrayMethod_Context *context, char *const data[],
197+
npy_intp const dimensions[], npy_intp const strides[],
198+
NpyAuxData *NPY_UNUSED(auxdata))
199+
{
200+
PyArray_Descr **descrs = context->descriptors;
201+
long in_size = ((ASCIIDTypeObject *)descrs[0])->size;
202+
long out_size = (descrs[1]->elsize) / 4;
203+
long copy_size;
204+
205+
if (out_size > in_size) {
206+
copy_size = in_size;
207+
}
208+
else {
209+
copy_size = out_size;
210+
}
211+
212+
npy_intp N = dimensions[0];
213+
char *in = data[0];
214+
char *out = data[1];
215+
npy_intp in_stride = strides[0];
216+
npy_intp out_stride = strides[1];
217+
218+
while (N--) {
219+
// copy ASCII input to first byte, fill rest with zeros
220+
for (int i = 0; i < copy_size; i++) {
221+
*(out + i * 4) = *(in + i);
222+
for (int j = 1; j < 4; j++) {
223+
*(out + i * 4 + j) = '\0';
224+
}
225+
}
226+
// fill all remaining UCS4 characters with zeros
227+
for (int i = copy_size; i < out_size; i++) {
228+
for (int j = 0; j < 4; j++) {
229+
*(out + i * 4 + j) = '\0';
230+
}
231+
}
232+
in += in_stride;
233+
out += out_stride;
234+
}
235+
return 0;
236+
}
237+
238+
static NPY_CASTING
239+
ascii_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
240+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
241+
PyArray_Descr *given_descrs[2],
242+
PyArray_Descr *loop_descrs[2],
243+
npy_intp *NPY_UNUSED(view_offset))
244+
{
245+
Py_INCREF(given_descrs[0]);
246+
loop_descrs[0] = given_descrs[0];
247+
if (given_descrs[1] == NULL) {
248+
Py_INCREF(given_descrs[0]);
249+
loop_descrs[1] = given_descrs[0];
250+
}
251+
else {
252+
Py_INCREF(given_descrs[1]);
253+
loop_descrs[1] = given_descrs[1];
254+
}
255+
256+
return NPY_SAME_KIND_CASTING;
257+
}
258+
259+
static int
260+
ascii_to_unicode_get_loop(PyArrayMethod_Context *NPY_UNUSED(context),
261+
int NPY_UNUSED(aligned),
262+
int NPY_UNUSED(move_references),
263+
const npy_intp *NPY_UNUSED(strides),
264+
PyArrayMethod_StridedLoop **out_loop,
265+
NpyAuxData **NPY_UNUSED(out_transferdata),
266+
NPY_ARRAYMETHOD_FLAGS *flags)
267+
{
268+
*out_loop = (PyArrayMethod_StridedLoop *)&ascii_to_unicode;
269+
270+
*flags = 0;
271+
return 0;
272+
}
273+
90274
static PyArray_DTypeMeta *a2a_dtypes[2] = {NULL, NULL};
91275

92276
static PyType_Slot a2a_slots[] = {
@@ -103,3 +287,59 @@ PyArrayMethod_Spec ASCIIToASCIICastSpec = {
103287
.dtypes = a2a_dtypes,
104288
.slots = a2a_slots,
105289
};
290+
291+
static PyType_Slot u2a_slots[] = {
292+
{NPY_METH_resolve_descriptors, &unicode_to_ascii_resolve_descriptors},
293+
{_NPY_METH_get_loop, &unicode_to_ascii_get_loop},
294+
{0, NULL}};
295+
296+
static char *u2a_name = "cast_Unicode_to_ASCIIDType";
297+
298+
static PyType_Slot a2u_slots[] = {
299+
{NPY_METH_resolve_descriptors, &ascii_to_unicode_resolve_descriptors},
300+
{_NPY_METH_get_loop, &ascii_to_unicode_get_loop},
301+
{0, NULL}};
302+
303+
static char *a2u_name = "cast_ASCIIDType_to_Unicode";
304+
305+
PyArrayMethod_Spec **
306+
get_casts(void)
307+
{
308+
PyArray_DTypeMeta **u2a_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
309+
u2a_dtypes[0] = &PyArray_UnicodeDType;
310+
u2a_dtypes[1] = NULL;
311+
312+
PyArrayMethod_Spec *UnicodeToASCIICastSpec =
313+
malloc(sizeof(PyArrayMethod_Spec));
314+
315+
UnicodeToASCIICastSpec->name = u2a_name;
316+
UnicodeToASCIICastSpec->nin = 1;
317+
UnicodeToASCIICastSpec->nout = 1;
318+
UnicodeToASCIICastSpec->flags = NPY_METH_SUPPORTS_UNALIGNED;
319+
UnicodeToASCIICastSpec->casting = NPY_SAME_KIND_CASTING;
320+
UnicodeToASCIICastSpec->dtypes = u2a_dtypes;
321+
UnicodeToASCIICastSpec->slots = u2a_slots;
322+
323+
PyArray_DTypeMeta **a2u_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
324+
a2u_dtypes[0] = NULL;
325+
a2u_dtypes[1] = &PyArray_UnicodeDType;
326+
327+
PyArrayMethod_Spec *ASCIIToUnicodeCastSpec =
328+
malloc(sizeof(PyArrayMethod_Spec));
329+
330+
ASCIIToUnicodeCastSpec->name = a2u_name;
331+
ASCIIToUnicodeCastSpec->nin = 1;
332+
ASCIIToUnicodeCastSpec->nout = 1;
333+
ASCIIToUnicodeCastSpec->flags = NPY_METH_SUPPORTS_UNALIGNED;
334+
ASCIIToUnicodeCastSpec->casting = NPY_SAME_KIND_CASTING;
335+
ASCIIToUnicodeCastSpec->dtypes = a2u_dtypes;
336+
ASCIIToUnicodeCastSpec->slots = a2u_slots;
337+
338+
PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *));
339+
casts[0] = &ASCIIToASCIICastSpec;
340+
casts[1] = UnicodeToASCIICastSpec;
341+
casts[2] = ASCIIToUnicodeCastSpec;
342+
casts[3] = NULL;
343+
344+
return casts;
345+
}

asciidtype/asciidtype/src/casts.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "numpy/experimental_dtype_api.h"
1111
#include "numpy/ndarraytypes.h"
1212

13-
extern PyArrayMethod_Spec ASCIIToASCIICastSpec;
13+
PyArrayMethod_Spec **
14+
get_casts(void);
1415

1516
#endif /* _NPY_CASTS_H */

asciidtype/asciidtype/src/dtype.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ PyArray_DTypeMeta ASCIIDType = {
245245
int
246246
init_ascii_dtype(void)
247247
{
248-
static PyArrayMethod_Spec *casts[] = {&ASCIIToASCIICastSpec, NULL};
248+
PyArrayMethod_Spec **casts = get_casts();
249249

250250
PyArrayDTypeMeta_Spec ASCIIDType_DTypeSpec = {
251251
.flags = NPY_DT_PARAMETRIC,
@@ -273,5 +273,11 @@ init_ascii_dtype(void)
273273

274274
ASCIIDType.singleton = singleton;
275275

276+
free(ASCIIDType_DTypeSpec.casts[1]->dtypes);
277+
free(ASCIIDType_DTypeSpec.casts[1]);
278+
free(ASCIIDType_DTypeSpec.casts[2]->dtypes);
279+
free(ASCIIDType_DTypeSpec.casts[2]);
280+
free(ASCIIDType_DTypeSpec.casts);
281+
276282
return 0;
277283
}

asciidtype/tests/test_asciidtype.py

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
from asciidtype import ASCIIDType, ASCIIScalar
45

@@ -50,24 +51,61 @@ def test_creation_truncation():
5051

5152

5253
def test_casting_to_asciidtype():
53-
arr = np.array(["hello", "this", "is", "an", "array"], dtype=ASCIIDType(5))
54-
55-
assert repr(arr.astype(ASCIIDType(7))) == (
56-
"array(['hello', 'this', 'is', 'an', 'array'], dtype=ASCIIDType(7))"
57-
)
58-
59-
assert repr(arr.astype(ASCIIDType(5))) == (
60-
"array(['hello', 'this', 'is', 'an', 'array'], dtype=ASCIIDType(5))"
61-
)
62-
63-
assert repr(arr.astype(ASCIIDType(4))) == (
64-
"array(['hell', 'this', 'is', 'an', 'arra'], dtype=ASCIIDType(4))"
65-
)
66-
67-
assert repr(arr.astype(ASCIIDType(1))) == (
68-
"array(['h', 't', 'i', 'a', 'a'], dtype=ASCIIDType(1))"
69-
)
70-
71-
# assert repr(arr.astype(ASCIIDType())) == (
72-
# "array(['', '', '', '', ''], dtype=ASCIIDType(0))"
73-
# )
54+
for dtype in (None, ASCIIDType(5)):
55+
arr = np.array(["this", "is", "an", "array"], dtype=dtype)
56+
57+
assert repr(arr.astype(ASCIIDType(7))) == (
58+
"array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(7))"
59+
)
60+
61+
assert repr(arr.astype(ASCIIDType(5))) == (
62+
"array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(5))"
63+
)
64+
65+
assert repr(arr.astype(ASCIIDType(4))) == (
66+
"array(['this', 'is', 'an', 'arra'], dtype=ASCIIDType(4))"
67+
)
68+
69+
assert repr(arr.astype(ASCIIDType(1))) == (
70+
"array(['t', 'i', 'a', 'a'], dtype=ASCIIDType(1))"
71+
)
72+
73+
# assert repr(arr.astype(ASCIIDType())) == (
74+
# "array(['', '', '', '', ''], dtype=ASCIIDType(0))"
75+
# )
76+
77+
78+
def test_unicode_to_ascii_to_unicode():
79+
arr = np.array(["hello", "this", "is", "an", "array"])
80+
ascii_arr = arr.astype(ASCIIDType(5))
81+
round_trip_arr = ascii_arr.astype("U5")
82+
np.testing.assert_array_equal(arr, round_trip_arr)
83+
84+
85+
def test_creation_fails_with_non_ascii_characters():
86+
inps = [
87+
["😀", "¡", "©", "ÿ"],
88+
["😀", "hello", "some", "ascii"],
89+
["hello", "some", "ascii", "😀"],
90+
]
91+
for inp in inps:
92+
with pytest.raises(
93+
TypeError,
94+
match="Can only store ASCII text in a ASCIIDType array.",
95+
):
96+
np.array(inp, dtype=ASCIIDType(5))
97+
98+
99+
def test_casting_fails_with_non_ascii_characters():
100+
inps = [
101+
["😀", "¡", "©", "ÿ"],
102+
["😀", "hello", "some", "ascii"],
103+
["hello", "some", "ascii", "😀"],
104+
]
105+
for inp in inps:
106+
arr = np.array(inp)
107+
with pytest.raises(
108+
TypeError,
109+
match="Can only store ASCII text in a ASCIIDType array.",
110+
):
111+
arr.astype(ASCIIDType(5))

0 commit comments

Comments
 (0)