@@ -41,6 +41,7 @@ from pandas._libs.algos import (
4141 ensure_platform_int,
4242 groupsort_indexer,
4343 rank_1d,
44+ take_2d_axis1_bool_bool,
4445 take_2d_axis1_float64_float64,
4546)
4647
@@ -64,11 +65,48 @@ cdef enum InterpolationEnumType:
6465 INTERPOLATION_MIDPOINT
6566
6667
67- cdef inline float64_t median_linear (float64_t* a, int n) nogil:
# Median of the entries of ``a[0:n]`` whose flag in ``mask[0:n]`` is zero.
#
# a    : values, read at indices 0..n-1
# n    : number of entries considered in both ``a`` and ``mask``
# mask : a non-zero flag marks the matching value as missing
#
# Returns NaN when ``n == 0`` or when every entry is masked; otherwise the
# value computed by ``calc_median_linear`` over the unmasked entries.
cdef inline float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) nogil:
    cdef:
        int i, j, na_count = 0
        float64_t* tmp
        float64_t result

    if n == 0:
        return NaN

    # count NAs
    for i in range(n):
        if mask[i]:
            na_count += 1

    if na_count:
        if na_count == n:
            return NaN

        # Compact the unmasked values into a scratch buffer so the median
        # helper only sees valid entries.
        tmp = <float64_t*>malloc((n - na_count) * sizeof(float64_t))
        if tmp is NULL:
            # malloc may fail; without this check the copy loop below would
            # write through a NULL pointer. Raising requires the GIL, which
            # this nogil function does not hold.
            # NOTE(review): how the error reaches the caller depends on this
            # function's exception specification -- confirm for the Cython
            # version in use.
            with gil:
                raise MemoryError()

        j = 0
        for i in range(n):
            if not mask[i]:
                tmp[j] = a[i]
                j += 1

        a = tmp
        n -= na_count

    # With no masked entries the caller's buffer is passed through directly.
    result = calc_median_linear(a, n, na_count)

    if na_count:
        # ``a`` points at the scratch buffer allocated above in this case.
        free(a)

    return result

103+
104+
105+ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
106+ cdef:
107+ int i, j, na_count = 0
71108 float64_t* tmp
109+ float64_t result
72110
73111 if n == 0 :
74112 return NaN
@@ -93,18 +131,34 @@ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
93131 a = tmp
94132 n -= na_count
95133
134+ result = calc_median_linear(a, n, na_count)
135+
136+ if na_count:
137+ free(a)
138+
139+ return result
140+
141+
# Median of ``a[0:n]`` via ``kth_smallest_c``.
#
# Odd ``n``  : the element of rank ``n // 2``.
# Even ``n`` : the mean of the elements of rank ``n // 2`` and ``n // 2 - 1``.
#
# ``na_count`` is accepted for the callers' benefit but is not read here.
cdef inline float64_t calc_median_linear(float64_t* a, int n, int na_count) nogil:
    cdef:
        int mid = n // 2
        float64_t upper, lower

    if n % 2 != 0:
        return kth_smallest_c(a, mid, n)

    # Same two selections as the single-expression form, upper rank first.
    upper = kth_smallest_c(a, mid, n)
    lower = kth_smallest_c(a, mid - 1, n)
    return (upper + lower) / 2

106153
107154
# Fused type over the four C numeric types listed below (signed and unsigned
# 64-bit integers, 32- and 64-bit floats).
155+ ctypedef fused int64float_t:
156+ int64_t
157+ uint64_t
158+ float32_t
159+ float64_t
160+
161+
108162@ cython.boundscheck (False )
109163@ cython.wraparound (False )
110164def group_median_float64 (
@@ -113,6 +167,8 @@ def group_median_float64(
113167 ndarray[float64_t , ndim = 2 ] values,
114168 ndarray[intp_t] labels ,
115169 Py_ssize_t min_count = - 1 ,
170+ const uint8_t[:, :] mask = None ,
171+ uint8_t[:, ::1] result_mask = None ,
116172) -> None:
117173 """
118174 Only aggregates on axis = 0
@@ -121,8 +177,12 @@ def group_median_float64(
121177 Py_ssize_t i , j , N , K , ngroups , size
122178 ndarray[intp_t] _counts
123179 ndarray[float64_t , ndim = 2 ] data
180+ ndarray[uint8_t , ndim = 2 ] data_mask
124181 ndarray[intp_t] indexer
125182 float64_t* ptr
183+ uint8_t* ptr_mask
184+ float64_t result
185+ bint uses_mask = mask is not None
126186
127187 assert min_count == -1, "'min_count' only used in sum and prod"
128188
@@ -137,15 +197,38 @@ def group_median_float64(
137197
138198 take_2d_axis1_float64_float64(values.T , indexer , out = data)
139199
140- with nogil:
200+ if uses_mask:
201+ data_mask = np.empty((K, N), dtype = np.uint8)
202+ ptr_mask = < uint8_t * > cnp.PyArray_DATA(data_mask)
203+
204+ take_2d_axis1_bool_bool(mask.T , indexer , out = data_mask, fill_value = 1 )
141205
142- for i in range(K ):
143- # exclude NA group
144- ptr += _counts[0 ]
145- for j in range (ngroups):
146- size = _counts[j + 1 ]
147- out[j, i] = median_linear(ptr, size)
148- ptr += size
206+ with nogil:
207+
208+ for i in range(K ):
209+ # exclude NA group
210+ ptr += _counts[0 ]
211+ ptr_mask += _counts[0 ]
212+
213+ for j in range (ngroups):
214+ size = _counts[j + 1 ]
215+ result = median_linear_mask(ptr, size, ptr_mask)
216+ out[j, i] = result
217+
218+ if result != result:
219+ result_mask[j, i] = 1
220+ ptr += size
221+ ptr_mask += size
222+
223+ else :
224+ with nogil:
225+ for i in range (K):
226+ # exclude NA group
227+ ptr += _counts[0 ]
228+ for j in range (ngroups):
229+ size = _counts[j + 1 ]
230+ out[j, i] = median_linear(ptr, size)
231+ ptr += size
149232
150233
151234@ cython.boundscheck (False )
@@ -206,13 +289,6 @@ def group_cumprod_float64(
206289 accum[lab, j] = NaN
207290
208291
209- ctypedef fused int64float_t:
210- int64_t
211- uint64_t
212- float32_t
213- float64_t
214-
215-
216292@ cython.boundscheck (False )
217293@ cython.wraparound (False )
218294def group_cumsum (
0 commit comments