@@ -947,12 +947,14 @@ def rank_1d(
947947 TiebreakEnumType tiebreak
948948 Py_ssize_t i, j, N, grp_start= 0 , dups= 0 , sum_ranks= 0
949949 Py_ssize_t grp_vals_seen= 1 , grp_na_count= 0
950- ndarray[int64_t, ndim= 1 ] lexsort_indexer
951- ndarray[float64_t, ndim= 1 ] grp_sizes, out
950+ ndarray[int64_t, ndim= 1 ] grp_sizes
951+ ndarray[intp_t, ndim= 1 ] lexsort_indexer
952+ ndarray[float64_t, ndim= 1 ] out
952953 ndarray[rank_t, ndim= 1 ] masked_vals
953954 ndarray[uint8_t, ndim= 1 ] mask
954955 bint keep_na, at_end, next_val_diff, check_labels, group_changed
955956 rank_t nan_fill_val
957+ int64_t grp_size
956958
957959 tiebreak = tiebreakers[ties_method]
958960 if tiebreak == TIEBREAK_FIRST:
@@ -965,7 +967,7 @@ def rank_1d(
965967 # TODO Cython 3.0: cast won't be necessary (#2992)
966968 assert < Py_ssize_t> len (labels) == N
967969 out = np.empty(N)
968- grp_sizes = np.ones(N)
970+ grp_sizes = np.ones(N, dtype = np.int64 )
969971
970972 # If all 0 labels, can short-circuit later label
971973 # comparisons
@@ -1022,7 +1024,7 @@ def rank_1d(
10221024 # each label corresponds to a different group value,
10231025 # the mask helps you differentiate missing values before
10241026 # performing sort on the actual values
1025- lexsort_indexer = np.lexsort(order).astype(np.int64 , copy = False )
1027+ lexsort_indexer = np.lexsort(order).astype(np.intp , copy = False )
10261028
10271029 if not ascending:
10281030 lexsort_indexer = lexsort_indexer[::- 1 ]
@@ -1093,13 +1095,15 @@ def rank_1d(
10931095 for j in range (i - dups + 1 , i + 1 ):
10941096 out[lexsort_indexer[j]] = grp_vals_seen
10951097
1096- # Look forward to the next value (using the sorting in lexsort_indexer)
1097- # if the value does not equal the current value then we need to
1098- # reset the dups and sum_ranks, knowing that a new value is
1099- # coming up. The conditional also needs to handle nan equality
1100- # and the end of iteration
1101- if next_val_diff or (mask[lexsort_indexer[i]]
1102- ^ mask[lexsort_indexer[i+ 1 ]]):
1098+ # Look forward to the next value (using the sorting in
1099+ # lexsort_indexer). If the value does not equal the current
1100+ # value then we need to reset the dups and sum_ranks, knowing
1101+ # that a new value is coming up. The conditional also needs
1102+ # to handle nan equality and the end of iteration. If group
1103+ # changes we do not record seeing a new value in the group
1104+ if not group_changed and (next_val_diff or
1105+ (mask[lexsort_indexer[i]]
1106+ ^ mask[lexsort_indexer[i+ 1 ]])):
11031107 dups = sum_ranks = 0
11041108 grp_vals_seen += 1
11051109
@@ -1110,14 +1114,21 @@ def rank_1d(
11101114 # group encountered (used by pct calculations later). Also be
11111115 # sure to reset any of the items helping to calculate dups
11121116 if group_changed:
1117+
1118+ # If not dense tiebreak, group size used to compute
1119+ # percentile will be # of non-null elements in group
11131120 if tiebreak != TIEBREAK_DENSE:
1114- for j in range (grp_start, i + 1 ):
1115- grp_sizes[lexsort_indexer[j]] = \
1116- (i - grp_start + 1 - grp_na_count)
1121+ grp_size = i - grp_start + 1 - grp_na_count
1122+
1123+ # Otherwise, it will be the number of distinct values
1124+ # in the group, subtracting 1 if NaNs are present
1125+ # since that is a distinct value we shouldn't count
11171126 else :
1118- for j in range (grp_start, i + 1 ):
1119- grp_sizes[lexsort_indexer[j]] = \
1120- (grp_vals_seen - 1 - (grp_na_count > 0 ))
1127+ grp_size = grp_vals_seen - (grp_na_count > 0 )
1128+
1129+ for j in range (grp_start, i + 1 ):
1130+ grp_sizes[lexsort_indexer[j]] = grp_size
1131+
11211132 dups = sum_ranks = 0
11221133 grp_na_count = 0
11231134 grp_start = i + 1
@@ -1184,12 +1195,14 @@ def rank_1d(
11841195 out[lexsort_indexer[j]] = grp_vals_seen
11851196
11861197 # Look forward to the next value (using the sorting in
1187- # lexsort_indexer) if the value does not equal the current
1198+ # lexsort_indexer). If the value does not equal the current
11881199 # value then we need to reset the dups and sum_ranks, knowing
11891200 # that a new value is coming up. The conditional also needs
1190- # to handle nan equality and the end of iteration
1191- if next_val_diff or (mask[lexsort_indexer[i]]
1192- ^ mask[lexsort_indexer[i+ 1 ]]):
1201+ # to handle nan equality and the end of iteration. If group
1202+ # changes we do not record seeing a new value in the group
1203+ if not group_changed and (next_val_diff or
1204+ (mask[lexsort_indexer[i]]
1205+ ^ mask[lexsort_indexer[i+ 1 ]])):
11931206 dups = sum_ranks = 0
11941207 grp_vals_seen += 1
11951208
@@ -1200,14 +1213,21 @@ def rank_1d(
12001213 # group encountered (used by pct calculations later). Also be
12011214 # sure to reset any of the items helping to calculate dups
12021215 if group_changed:
1216+
1217+ # If not dense tiebreak, group size used to compute
1218+ # percentile will be # of non-null elements in group
12031219 if tiebreak != TIEBREAK_DENSE:
1204- for j in range (grp_start, i + 1 ):
1205- grp_sizes[lexsort_indexer[j]] = \
1206- (i - grp_start + 1 - grp_na_count)
1220+ grp_size = i - grp_start + 1 - grp_na_count
1221+
1222+ # Otherwise, it will be the number of distinct values
1223+ # in the group, subtracting 1 if NaNs are present
1224+ # since that is a distinct value we shouldn't count
12071225 else :
1208- for j in range (grp_start, i + 1 ):
1209- grp_sizes[lexsort_indexer[j]] = \
1210- (grp_vals_seen - 1 - (grp_na_count > 0 ))
1226+ grp_size = grp_vals_seen - (grp_na_count > 0 )
1227+
1228+ for j in range (grp_start, i + 1 ):
1229+ grp_sizes[lexsort_indexer[j]] = grp_size
1230+
12111231 dups = sum_ranks = 0
12121232 grp_na_count = 0
12131233 grp_start = i + 1
0 commit comments