11#include "casts.h"
22
33#include "dtype.h"
4+ #include "static_string.h"
5+
6+ void
7+ gil_error (PyObject * type , const char * msg )
8+ {
9+ PyGILState_STATE gstate ;
10+ gstate = PyGILState_Ensure ();
11+ PyErr_SetString (type , msg );
12+ PyGILState_Release (gstate );
13+ }
414
515static NPY_CASTING
616string_to_string_resolve_descriptors (PyObject * NPY_UNUSED (self ),
@@ -35,17 +45,19 @@ string_to_string(PyArrayMethod_Context *context, char *const data[],
3545 NpyAuxData * NPY_UNUSED (auxdata ))
3646{
3747 npy_intp N = dimensions [0 ];
38- char * * in = (char * * )data [0 ];
39- char * * out = (char * * )data [1 ];
48+ ss * * in = (ss * * )data [0 ];
49+ ss * * out = (ss * * )data [1 ];
4050 // strides are in bytes but pointer offsets are in pointer widths, so
4151 // divide by the element size (one pointer width) to get the pointer offset
4252 npy_intp in_stride = strides [0 ] / context -> descriptors [0 ]-> elsize ;
4353 npy_intp out_stride = strides [1 ] / context -> descriptors [1 ]-> elsize ;
4454
4555 while (N -- ) {
46- size_t length = strlen (* in );
47- out [0 ] = (char * )malloc ((sizeof (char ) * length ) + 1 );
48- strncpy (* out , * in , length + 1 );
56+ out [0 ] = ssdup (in [0 ]);
57+ if (out [0 ] == NULL ) {
58+ gil_error (PyExc_MemoryError , "ssdup failed" );
59+ return -1 ;
60+ }
4961 in += in_stride ;
5062 out += out_stride ;
5163 }
@@ -189,7 +201,7 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
189201
190202 npy_intp N = dimensions [0 ];
191203 Py_UCS4 * in = (Py_UCS4 * )data [0 ];
192- char * * out = (char * * )data [1 ];
204+ ss * * out = (ss * * )data [1 ];
193205
194206 // 4 bytes per UCS4 character
195207 npy_intp in_stride = strides [0 ] / 4 ;
@@ -202,16 +214,14 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
202214 size_t num_codepoints = 0 ;
203215 if (utf8_size (in , max_in_size , & num_codepoints , & out_num_bytes ) ==
204216 -1 ) {
205- // invalid codepoint found so acquire GIL, set error, return
206- PyGILState_STATE gstate ;
207- gstate = PyGILState_Ensure ();
208- PyErr_SetString (PyExc_TypeError ,
209- "Invalid unicode code point found" );
210- PyGILState_Release (gstate );
217+ gil_error (PyExc_TypeError , "Invalid unicode code point found" );
211218 return -1 ;
212219 }
213- // one extra byte for null terminator
214- char * out_buf = malloc ((out_num_bytes + 1 ) * sizeof (char ));
220+ ss * out_ss = ssnewempty (out_num_bytes );
221+ if (out_ss == NULL ) {
222+ gil_error (PyExc_MemoryError , "ssnewempty failed" );
223+ }
224+ char * out_buf = out_ss -> buf ;
215225 for (int i = 0 ; i < num_codepoints ; i ++ ) {
216226 // get code point
217227 Py_UCS4 code = in [i ];
@@ -237,7 +247,7 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
237247 out_buf [out_num_bytes ] = '\0' ;
238248
239249 // set out to the address of the beginning of the string
240- out [0 ] = out_buf ;
250+ out [0 ] = out_ss ;
241251
242252 in += in_stride ;
243253 out += out_stride ;
@@ -318,7 +328,7 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
318328 NpyAuxData * NPY_UNUSED (auxdata ))
319329{
320330 npy_intp N = dimensions [0 ];
321- char * * in = (char * * )data [0 ];
331+ ss * * in = (ss * * )data [0 ];
322332 Py_UCS4 * out = (Py_UCS4 * )data [1 ];
323333 // strides are in bytes but pointer offsets are in pointer widths, so
324334 // divide by the element size (one pointer width) to get the pointer offset
@@ -329,7 +339,9 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
329339 long max_out_size = (context -> descriptors [1 ]-> elsize ) / 4 ;
330340
331341 while (N -- ) {
332- unsigned char * this_string = (unsigned char * )* in ;
342+ unsigned char * this_string = (unsigned char * )((* in )-> buf );
343+ size_t n_bytes = (* in )-> len ;
344+ size_t tot_n_bytes = 0 ;
333345
334346 for (int i = 0 ; i < max_out_size ; i ++ ) {
335347 Py_UCS4 code ;
@@ -340,16 +352,13 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
340352
341353 // move to next character
342354 this_string += num_bytes ;
355+ tot_n_bytes += num_bytes ;
343356
344357 // set output codepoint
345358 out [i ] = code ;
346359
347- // check if this is the null terminator
348- if (code == 0 ) {
349- // fill all remaining characters (if any) with zero
350- for (int j = i + 1 ; j < max_out_size ; j ++ ) {
351- out [j ] = 0 ;
352- }
360+ // stop if we've exhausted the input string
361+ if (tot_n_bytes >= n_bytes ) {
353362 break ;
354363 }
355364 }
0 commit comments