@@ -49,7 +49,7 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
4949
5050static NPY_CASTING
5151string_equal_resolve_descriptors (PyObject * NPY_UNUSED (self ),
52- PyArray_DTypeMeta * dtypes [],
52+ PyArray_DTypeMeta * NPY_UNUSED ( dtypes []) ,
5353 PyArray_Descr * given_descrs [],
5454 PyArray_Descr * loop_descrs [],
5555 npy_intp * NPY_UNUSED (view_offset ))
@@ -61,7 +61,42 @@ string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
6161
6262 loop_descrs [2 ] = PyArray_DescrFromType (NPY_BOOL ); // cannot fail
6363
64- return NPY_SAFE_CASTING ;
64+ return NPY_NO_CASTING ;
65+ }
66+
67+ static int
68+ string_isnan_strided_loop (PyArrayMethod_Context * NPY_UNUSED (context ),
69+ char * const data [], npy_intp const dimensions [],
70+ npy_intp const strides [],
71+ NpyAuxData * NPY_UNUSED (auxdata ))
72+ {
73+ npy_intp N = dimensions [0 ];
74+ npy_bool * out = (npy_bool * )data [1 ];
75+ npy_intp out_stride = strides [1 ];
76+
77+ while (N -- ) {
78+ // we could represent missing data with a null pointer, but
79+ // should isnan return True in that case?
80+ * out = (npy_bool )0 ;
81+
82+ out += out_stride ;
83+ }
84+
85+ return 0 ;
86+ }
87+
88+ static NPY_CASTING
89+ string_isnan_resolve_descriptors (PyObject * NPY_UNUSED (self ),
90+ PyArray_DTypeMeta * NPY_UNUSED (dtypes []),
91+ PyArray_Descr * given_descrs [],
92+ PyArray_Descr * loop_descrs [],
93+ npy_intp * NPY_UNUSED (view_offset ))
94+ {
95+ Py_INCREF (given_descrs [0 ]);
96+ loop_descrs [0 ] = given_descrs [0 ];
97+ loop_descrs [1 ] = PyArray_DescrFromType (NPY_BOOL ); // cannot fail
98+
99+ return NPY_NO_CASTING ;
65100}
66101
67102/*
@@ -131,73 +166,70 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
131166}
132167
133168int
134- init_equal_ufunc (PyObject * numpy )
169+ init_ufunc (PyObject * numpy , const char * ufunc_name , PyArray_DTypeMeta * * dtypes ,
170+ resolve_descriptors_function * resolve_func ,
171+ PyArrayMethod_StridedLoop * loop_func , const char * loop_name ,
172+ int nin , int nout , NPY_CASTING casting , NPY_ARRAYMETHOD_FLAGS flags )
135173{
136- PyObject * equal = PyObject_GetAttrString (numpy , "equal" );
137- if (equal == NULL ) {
174+ PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
175+ if (ufunc == NULL ) {
138176 return -1 ;
139177 }
140178
141179 /*
142180 * Initialize spec for equality
143181 */
144- PyArray_DTypeMeta * eq_dtypes [3 ] = {& StringDType , & StringDType ,
145- & PyArray_BoolDType };
146-
147- static PyType_Slot eq_slots [] = {
148- {NPY_METH_resolve_descriptors , & string_equal_resolve_descriptors },
149- {NPY_METH_strided_loop , & string_equal_strided_loop },
150- {0 , NULL }};
151-
152- PyArrayMethod_Spec EqualSpec = {
153- .name = "string_equal" ,
154- .nin = 2 ,
155- .nout = 1 ,
156- .casting = NPY_NO_CASTING ,
157- .flags = 0 ,
158- .dtypes = eq_dtypes ,
159- .slots = eq_slots ,
182+ PyType_Slot slots [] = {{NPY_METH_resolve_descriptors , resolve_func },
183+ {NPY_METH_strided_loop , loop_func },
184+ {0 , NULL }};
185+
186+ PyArrayMethod_Spec spec = {
187+ .name = loop_name ,
188+ .nin = nin ,
189+ .nout = nout ,
190+ .casting = casting ,
191+ .flags = flags ,
192+ .dtypes = dtypes ,
193+ .slots = slots ,
160194 };
161195
162- if (PyUFunc_AddLoopFromSpec (equal , & EqualSpec ) < 0 ) {
163- Py_DECREF (equal );
196+ if (PyUFunc_AddLoopFromSpec (ufunc , & spec ) < 0 ) {
197+ Py_DECREF (ufunc );
164198 return -1 ;
165199 }
166200
167- /*
168- * Add promoter to ufunc, ensures operations that mix StringDType and
169- * UnicodeDType cast the unicode argument to string.
170- */
201+ Py_DECREF (ufunc );
202+ return 0 ;
203+ }
171204
172- PyObject * DTypes [] = {
173- PyTuple_Pack (3 , & StringDType , & PyArray_UnicodeDType ,
174- & PyArray_BoolDType ),
175- PyTuple_Pack (3 , & PyArray_UnicodeDType , & StringDType ,
176- & PyArray_BoolDType ),
177- };
205+ int
206+ add_promoter (PyObject * numpy , const char * ufunc_name ,
207+ PyArray_DTypeMeta * * dtypes )
208+ {
209+ PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
210+ if (ufunc == NULL ) {
211+ return -1 ;
212+ }
178213
179- if ((DTypes [0 ] == NULL ) || (DTypes [1 ] == NULL )) {
180- Py_DECREF (equal );
214+ PyObject * DType_tuple = PyTuple_Pack (3 , dtypes [0 ], dtypes [1 ], dtypes [2 ]);
215+ if (DType_tuple == NULL ) {
216+ Py_DECREF (ufunc );
181217 return -1 ;
182218 }
183219
184220 PyObject * promoter_capsule = PyCapsule_New ((void * )& default_ufunc_promoter ,
185221 "numpy._ufunc_promoter" , NULL );
186222
187- for (int i = 0 ; i < 2 ; i ++ ) {
188- if (PyUFunc_AddPromoter (equal , DTypes [i ], promoter_capsule ) < 0 ) {
189- Py_DECREF (promoter_capsule );
190- Py_DECREF (DTypes [0 ]);
191- Py_DECREF (DTypes [1 ]);
192- Py_DECREF (equal );
193- return -1 ;
194- }
223+ if (PyUFunc_AddPromoter (ufunc , DType_tuple , promoter_capsule ) < 0 ) {
224+ Py_DECREF (promoter_capsule );
225+ Py_DECREF (DType_tuple );
226+ Py_DECREF (ufunc );
227+ return -1 ;
195228 }
196229
197230 Py_DECREF (promoter_capsule );
198- Py_DECREF (DTypes [0 ]);
199- Py_DECREF (DTypes [1 ]);
200- Py_DECREF (equal );
231+ Py_DECREF (DType_tuple );
232+ Py_DECREF (ufunc );
201233
202234 return 0 ;
203235}
@@ -210,7 +242,35 @@ init_ufuncs(void)
210242 return -1 ;
211243 }
212244
213- if (init_equal_ufunc (numpy ) < 0 ) {
245+ PyArray_DTypeMeta * eq_dtypes [] = {& StringDType , & StringDType ,
246+ & PyArray_BoolDType };
247+
248+ if (init_ufunc (numpy , "equal" , eq_dtypes ,
249+ & string_equal_resolve_descriptors ,
250+ & string_equal_strided_loop , "string_equal" , 2 , 1 ,
251+ NPY_NO_CASTING , 0 ) < 0 ) {
252+ goto error ;
253+ }
254+
255+ PyArray_DTypeMeta * promoter_dtypes [2 ][3 ] = {
256+ {& StringDType , & PyArray_UnicodeDType , & PyArray_BoolDType },
257+ {& PyArray_UnicodeDType , & StringDType , & PyArray_BoolDType },
258+ };
259+
260+ if (add_promoter (numpy , "equal" , promoter_dtypes [0 ]) < 0 ) {
261+ goto error ;
262+ }
263+
264+ if (add_promoter (numpy , "equal" , promoter_dtypes [1 ]) < 0 ) {
265+ goto error ;
266+ }
267+
268+ PyArray_DTypeMeta * isnan_dtypes [] = {& StringDType , & PyArray_BoolDType };
269+
270+ if (init_ufunc (numpy , "isnan" , isnan_dtypes ,
271+ & string_isnan_resolve_descriptors ,
272+ & string_isnan_strided_loop , "string_isnan" , 1 , 1 ,
273+ NPY_NO_CASTING , 0 ) < 0 ) {
214274 goto error ;
215275 }
216276
0 commit comments