@@ -8,7 +8,7 @@ import numpy as np
88cimport numpy as cnp
99from numpy cimport (ndarray,
1010 int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
11- uint32_t, uint64_t, float32_t, float64_t)
11+ uint32_t, uint64_t, float32_t, float64_t, complex64_t, complex128_t )
1212cnp.import_array()
1313
1414
@@ -421,30 +421,38 @@ def group_any_all(uint8_t[:] out,
421421 if values[i] == flag_val:
422422 out[lab] = flag_val
423423
424+
424425# ----------------------------------------------------------------------
425426# group_add, group_prod, group_var, group_mean, group_ohlc
426427# ----------------------------------------------------------------------
427428
429+ ctypedef fused complexfloating_t:
430+ float64_t
431+ float32_t
432+ complex64_t
433+ complex128_t
434+
428435
429436@ cython.wraparound (False )
430437@ cython.boundscheck (False )
431- def _group_add (floating [:, :] out ,
438+ def _group_add (complexfloating_t [:, :] out ,
432439 int64_t[:] counts ,
433- floating [:, :] values ,
440+ complexfloating_t [:, :] values ,
434441 const int64_t[:] labels ,
435442 Py_ssize_t min_count = 0 ):
436443 """
437444 Only aggregates on axis=0
438445 """
439446 cdef:
440447 Py_ssize_t i, j, N, K, lab, ncounts = len (counts)
441- floating val, count
442- floating[:, :] sumx, nobs
448+ complexfloating_t val, count
449+ complexfloating_t[:, :] sumx
450+ int64_t[:, :] nobs
443451
444452 if len (values) != len (labels):
445453 raise ValueError (" len(index) != len(labels)" )
446454
447- nobs = np.zeros_like( out)
455+ nobs = np.zeros(( len ( out), out.shape[ 1 ]), dtype = np.int64 )
448456 sumx = np.zeros_like(out)
449457
450458 N, K = (< object > values).shape
@@ -462,7 +470,12 @@ def _group_add(floating[:, :] out,
462470 # not nan
463471 if val == val:
464472 nobs[lab, j] += 1
465- sumx[lab, j] += val
473+ if (complexfloating_t is complex64_t or
474+ complexfloating_t is complex128_t):
475+ # clang errors if we use += with these dtypes
476+ sumx[lab, j] = sumx[lab, j] + val
477+ else :
478+ sumx[lab, j] += val
466479
467480 for i in range (ncounts):
468481 for j in range (K):
@@ -472,8 +485,10 @@ def _group_add(floating[:, :] out,
472485 out[i, j] = sumx[i, j]
473486
474487
475- group_add_float32 = _group_add[' float' ]
476- group_add_float64 = _group_add[' double' ]
488+ group_add_float32 = _group_add[' float32_t' ]
489+ group_add_float64 = _group_add[' float64_t' ]
490+ group_add_complex64 = _group_add[' float complex' ]
491+ group_add_complex128 = _group_add[' double complex' ]
477492
478493
479494@ cython.wraparound (False )
0 commit comments