|
4 | 4 |
|
5 | 5 | PyTypeObject *ASCIIScalar_Type = NULL; |
6 | 6 |
|
7 | | -static char * |
| 7 | +static PyObject * |
8 | 8 | get_value(PyObject *scalar) |
9 | 9 | { |
| 10 | + PyObject *ret_bytes = NULL; |
10 | 11 | PyTypeObject *scalar_type = Py_TYPE(scalar); |
11 | | - if (scalar_type != ASCIIScalar_Type) { |
12 | | - PyErr_SetString(PyExc_TypeError, |
13 | | - "Can only store ASCIIScalar in a ASCIIDType array."); |
14 | | - return NULL; |
| 12 | + if (scalar_type == &PyUnicode_Type) { |
| 13 | + // attempt to decode as ASCII |
| 14 | + ret_bytes = PyUnicode_AsASCIIString(scalar); |
| 15 | + if (ret_bytes == NULL) { |
| 16 | + PyErr_SetString( |
| 17 | + PyExc_TypeError, |
| 18 | + "Can only store ASCII text in a ASCIIDType array."); |
| 19 | + } |
15 | 20 | } |
16 | | - |
17 | | - PyObject *value = PyObject_GetAttrString(scalar, "value"); |
18 | | - if (value == NULL) { |
| 21 | + else if (scalar_type != ASCIIScalar_Type) { |
| 22 | + PyErr_SetString(PyExc_TypeError, |
| 23 | + "Can only store ASCII text in a ASCIIDType array."); |
19 | 24 | return NULL; |
20 | 25 | } |
21 | | - PyObject *res_bytes = PyUnicode_AsASCIIString(value); |
22 | | - Py_DECREF(value); |
23 | | - char *res = PyBytes_AsString(res_bytes); |
24 | | - Py_DECREF(res_bytes); |
25 | | - if (res == NULL) { |
26 | | - return NULL; |
| 26 | + else { |
| 27 | + PyObject *value = PyObject_GetAttrString(scalar, "value"); |
| 28 | + if (value == NULL) { |
| 29 | + return NULL; |
| 30 | + } |
| 31 | + ret_bytes = PyUnicode_AsASCIIString(value); |
| 32 | + Py_DECREF(value); |
27 | 33 | } |
28 | | - return res; |
| 34 | + return ret_bytes; |
29 | 35 | } |
30 | 36 |
|
31 | 37 | /* |
@@ -100,24 +106,39 @@ ascii_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls), |
100 | 106 | static int |
101 | 107 | asciidtype_setitem(ASCIIDTypeObject *descr, PyObject *obj, char *dataptr) |
102 | 108 | { |
103 | | - char *value = get_value(obj); |
| 109 | + PyObject *value = get_value(obj); |
104 | 110 | if (value == NULL) { |
105 | 111 | return -1; |
106 | 112 | } |
107 | 113 |
|
108 | | - memcpy(dataptr, value, descr->size * sizeof(char)); // NOLINT |
| 114 | + Py_ssize_t len = PyBytes_Size(value); |
| 115 | + |
| 116 | + size_t copysize; |
| 117 | + |
| 118 | + if (len > descr->size) { |
| 119 | + copysize = descr->size; |
| 120 | + } |
| 121 | + else { |
| 122 | + copysize = len; |
| 123 | + } |
| 124 | + |
| 125 | + char *char_value = PyBytes_AsString(value); |
| 126 | + |
| 127 | + memcpy(dataptr, char_value, copysize * sizeof(char)); // NOLINT |
| 128 | + |
| 129 | + for (int i = copysize; i < descr->size; i++) { |
| 130 | + dataptr[i] = '\0'; |
| 131 | + } |
| 132 | + |
| 133 | + Py_DECREF(value); |
109 | 134 |
|
110 | 135 | return 0; |
111 | 136 | } |
112 | 137 |
|
113 | 138 | static PyObject * |
114 | 139 | asciidtype_getitem(ASCIIDTypeObject *descr, char *dataptr) |
115 | 140 | { |
116 | | - char *val = NULL; |
117 | | - /* get the value */ |
118 | | - memcpy(val, dataptr, descr->size * sizeof(char)); // NOLINT |
119 | | - |
120 | | - PyObject *val_obj = PyUnicode_FromStringAndSize(val, descr->size); |
| 141 | + PyObject *val_obj = PyUnicode_FromString(dataptr); |
121 | 142 | if (val_obj == NULL) { |
122 | 143 | return NULL; |
123 | 144 | } |
|
0 commit comments