@@ -1066,6 +1066,8 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
10661066 int64_t[::1] counts ,
10671067 ndarray[iu_64_floating_obj_t , ndim = 2 ] values,
10681068 const intp_t[::1] labels ,
1069+ const uint8_t[:, :] mask ,
1070+ uint8_t[:, ::1] result_mask = None ,
10691071 int64_t min_count = - 1 ,
10701072 int64_t rank = 1 ,
10711073 ) -> None:
@@ -1078,6 +1080,8 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
10781080 ndarray[iu_64_floating_obj_t , ndim = 2 ] resx
10791081 ndarray[int64_t , ndim = 2 ] nobs
10801082 bint runtime_error = False
1083+ bint uses_mask = mask is not None
1084+ bint isna_entry
10811085
10821086 # TODO(cython3 ):
10831087 # Instead of `labels.shape[0]` use `len(labels)`
@@ -1104,7 +1108,12 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
11041108 for j in range (K):
11051109 val = values[i, j]
11061110
1107- if not checknull(val):
1111+ if uses_mask:
1112+ isna_entry = mask[i, j]
1113+ else :
1114+ isna_entry = checknull(val)
1115+
1116+ if not isna_entry:
11081117 # NB: use _treat_as_na here once
11091118 # conditional-nogil is available.
11101119 nobs[lab, j] += 1
@@ -1129,16 +1138,24 @@ def group_nth(iu_64_floating_obj_t[:, ::1] out,
11291138 for j in range (K):
11301139 val = values[i, j]
11311140
1132- if not _treat_as_na(val, True ):
1141+ if uses_mask:
1142+ isna_entry = mask[i, j]
1143+ else :
1144+ isna_entry = _treat_as_na(val, True )
11331145 # TODO: Sure we always want is_datetimelike=True?
1146+
1147+ if not isna_entry:
11341148 nobs[lab, j] += 1
11351149 if nobs[lab, j] == rank:
11361150 resx[lab, j] = val
11371151
11381152 for i in range (ncounts):
11391153 for j in range (K):
11401154 if nobs[i, j] < min_count:
1141- if iu_64_floating_obj_t is int64_t:
1155+ if uses_mask:
1156+ result_mask[i, j] = True
1157+ elif iu_64_floating_obj_t is int64_t:
1158+ # TODO: only if datetimelike?
11421159 out[i, j] = NPY_NAT
11431160 elif iu_64_floating_obj_t is uint64_t:
11441161 runtime_error = True
0 commit comments