@@ -995,22 +995,27 @@ def rank_1d(
995995 cdef:
996996 TiebreakEnumType tiebreak
997997 Py_ssize_t i, j, N, grp_start= 0 , dups= 0 , sum_ranks= 0
998- Py_ssize_t grp_vals_seen= 1 , grp_na_count= 0 , grp_tie_count = 0
998+ Py_ssize_t grp_vals_seen= 1 , grp_na_count= 0
999999 ndarray[int64_t, ndim= 1 ] lexsort_indexer
10001000 ndarray[float64_t, ndim= 1 ] grp_sizes, out
10011001 ndarray[rank_t, ndim= 1 ] masked_vals
10021002 ndarray[uint8_t, ndim= 1 ] mask
1003- bint keep_na, at_end, next_val_diff, check_labels
1003+ bint keep_na, at_end, next_val_diff, check_labels, group_changed
10041004 rank_t nan_fill_val
10051005
10061006 tiebreak = tiebreakers[ties_method]
1007+ if tiebreak == TIEBREAK_FIRST:
1008+ if not ascending:
1009+ tiebreak = TIEBREAK_FIRST_DESCENDING
1010+
10071011 keep_na = na_option == ' keep'
10081012
10091013 N = len (values)
10101014 # TODO Cython 3.0: cast won't be necessary (#2992)
10111015 assert < Py_ssize_t> len (labels) == N
10121016 out = np.empty(N)
10131017 grp_sizes = np.ones(N)
1018+
10141019 # If all 0 labels, can short-circuit later label
10151020 # comparisons
10161021 check_labels = np.any(labels)
@@ -1032,6 +1037,12 @@ def rank_1d(
10321037 else :
10331038 mask = np.zeros(shape = len (masked_vals), dtype = np.uint8)
10341039
1040+ # If `na_option == 'top'`, we want to assign the lowest rank
1041+ # to NaN regardless of ascending/descending. So if ascending,
1042+ # fill with lowest value of type to end up with lowest rank.
1043+ # If descending, fill with highest value since descending
1044+ # will flip the ordering to still end up with lowest rank.
1045+ # Symmetric logic applies to `na_option == 'bottom'`
10351046 if ascending ^ (na_option == ' top' ):
10361047 if rank_t is object :
10371048 nan_fill_val = Infinity()
@@ -1074,36 +1085,36 @@ def rank_1d(
10741085 if rank_t is object :
10751086 for i in range (N):
10761087 at_end = i == N - 1
1088+
10771089 # dups and sum_ranks will be incremented each loop where
10781090 # the value / group remains the same, and should be reset
1079- # when either of those change
1080- # Used to calculate tiebreakers
1091+ # when either of those change. Used to calculate tiebreakers
10811092 dups += 1
10821093 sum_ranks += i - grp_start + 1
10831094
1095+ next_val_diff = at_end or are_diff(masked_vals[lexsort_indexer[i]],
1096+ masked_vals[lexsort_indexer[i+ 1 ]])
1097+
1098+ # We'll need this check later anyway to determine group size, so just
1099+ # compute it here since short-circuiting won't help
1100+ group_changed = at_end or (check_labels and
1101+ (labels[lexsort_indexer[i]]
1102+ != labels[lexsort_indexer[i+ 1 ]]))
1103+
10841104 # Update out only when there is a transition of values or labels.
10851105 # When a new value or group is encountered, go back #dups steps
10861106 # (the number of occurrences of the current value) and assign the ranks
10871107 # based on the starting index of the current group (grp_start)
10881108 # and the current index
1089- if not at_end:
1090- next_val_diff = are_diff(masked_vals[lexsort_indexer[i]],
1091- masked_vals[lexsort_indexer[i+ 1 ]])
1092- else :
1093- next_val_diff = True
1094-
1095- if (next_val_diff
1096- or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+ 1 ]])
1097- or (check_labels
1098- and (labels[lexsort_indexer[i]]
1099- != labels[lexsort_indexer[i+ 1 ]]))
1100- ):
1101- # if keep_na, check for missing values and assign back
1109+ if (next_val_diff or group_changed
1110+ or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+ 1 ]])):
1111+
1112+ # If keep_na, check for missing values and assign back
11021113 # to the result where appropriate
11031114 if keep_na and mask[lexsort_indexer[i]]:
1115+ grp_na_count = dups
11041116 for j in range (i - dups + 1 , i + 1 ):
11051117 out[lexsort_indexer[j]] = NaN
1106- grp_na_count = dups
11071118 elif tiebreak == TIEBREAK_AVERAGE:
11081119 for j in range (i - dups + 1 , i + 1 ):
11091120 out[lexsort_indexer[j]] = sum_ranks / < float64_t> dups
@@ -1113,84 +1124,87 @@ def rank_1d(
11131124 elif tiebreak == TIEBREAK_MAX:
11141125 for j in range (i - dups + 1 , i + 1 ):
11151126 out[lexsort_indexer[j]] = i - grp_start + 1
1127+
1128+ # With n as the previous rank in the group and m as the number
1129+ # of duplicates in this stretch, if TIEBREAK_FIRST and ascending,
1130+ # then rankings should be n + 1, n + 2 ... n + m
11161131 elif tiebreak == TIEBREAK_FIRST:
11171132 for j in range (i - dups + 1 , i + 1 ):
1118- if ascending:
1119- out[lexsort_indexer[j]] = j + 1 - grp_start
1120- else :
1121- out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
1133+ out[lexsort_indexer[j]] = j + 1 - grp_start
1134+
1135+ # If TIEBREAK_FIRST and descending, the ranking should be
1136+ # n + m, n + (m - 1) ... n + 1. This is equivalent to
1137+ # (i - dups + 1) + (i - j + 1) - grp_start
1138+ elif tiebreak == TIEBREAK_FIRST_DESCENDING:
1139+ for j in range (i - dups + 1 , i + 1 ):
1140+ out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
11221141 elif tiebreak == TIEBREAK_DENSE:
11231142 for j in range (i - dups + 1 , i + 1 ):
11241143 out[lexsort_indexer[j]] = grp_vals_seen
11251144
1126- # look forward to the next value (using the sorting in _as )
1145+ # Look forward to the next value (using the sorting in lexsort_indexer )
11271146 # if the value does not equal the current value then we need to
11281147 # reset the dups and sum_ranks, knowing that a new value is
1129- # coming up. the conditional also needs to handle nan equality
1148+ # coming up. The conditional also needs to handle nan equality
11301149 # and the end of iteration
11311150 if next_val_diff or (mask[lexsort_indexer[i]]
11321151 ^ mask[lexsort_indexer[i+ 1 ]]):
11331152 dups = sum_ranks = 0
11341153 grp_vals_seen += 1
1135- grp_tie_count += 1
11361154
11371155 # Similar to the previous conditional, check now if we are
11381156 # moving to a new group. If so, keep track of the index where
11391157 # the new group occurs, so the tiebreaker calculations can
1140- # decrement that from their position. fill in the size of each
1141- # group encountered (used by pct calculations later). also be
1158+ # decrement that from their position. Fill in the size of each
1159+ # group encountered (used by pct calculations later). Also be
11421160 # sure to reset any of the items helping to calculate dups
1143- if (at_end or
1144- (check_labels
1145- and (labels[lexsort_indexer[i]]
1146- != labels[lexsort_indexer[i+ 1 ]]))):
1161+ if group_changed:
11471162 if tiebreak != TIEBREAK_DENSE:
11481163 for j in range (grp_start, i + 1 ):
11491164 grp_sizes[lexsort_indexer[j]] = \
11501165 (i - grp_start + 1 - grp_na_count)
11511166 else :
11521167 for j in range (grp_start, i + 1 ):
11531168 grp_sizes[lexsort_indexer[j]] = \
1154- (grp_tie_count - (grp_na_count > 0 ))
1169+ (grp_vals_seen - 1 - (grp_na_count > 0 ))
11551170 dups = sum_ranks = 0
11561171 grp_na_count = 0
1157- grp_tie_count = 0
11581172 grp_start = i + 1
11591173 grp_vals_seen = 1
11601174 else :
11611175 with nogil:
11621176 for i in range (N):
11631177 at_end = i == N - 1
1178+
11641179 # dups and sum_ranks will be incremented each loop where
11651180 # the value / group remains the same, and should be reset
1166- # when either of those change
1167- # Used to calculate tiebreakers
1181+ # when either of those change. Used to calculate tiebreakers
11681182 dups += 1
11691183 sum_ranks += i - grp_start + 1
11701184
1185+ next_val_diff = at_end or (masked_vals[lexsort_indexer[i]]
1186+ != masked_vals[lexsort_indexer[i+ 1 ]])
1187+
1188+ # We'll need this check later anyway to determine group size, so just
1189+ # compute it here since short-circuiting won't help
1190+ group_changed = at_end or (check_labels and
1191+ (labels[lexsort_indexer[i]]
1192+ != labels[lexsort_indexer[i+ 1 ]]))
1193+
11711194 # Update out only when there is a transition of values or labels.
11721196 # When a new value or group is encountered, go back #dups steps
11731197 # (the number of occurrences of the current value) and assign the ranks
11741197 # based on the starting index of the current group (grp_start)
11751198 # and the current index
1176- if not at_end:
1177- next_val_diff = (masked_vals[lexsort_indexer[i]]
1178- != masked_vals[lexsort_indexer[i+ 1 ]])
1179- else :
1180- next_val_diff = True
1181-
1182- if (next_val_diff
1183- or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+ 1 ]])
1184- or (check_labels
1185- and (labels[lexsort_indexer[i]]
1186- != labels[lexsort_indexer[i+ 1 ]]))
1187- ):
1188- # if keep_na, check for missing values and assign back
1199+ if (next_val_diff or group_changed
1200+ or (mask[lexsort_indexer[i]] ^ mask[lexsort_indexer[i+ 1 ]])):
1201+
1202+ # If keep_na, check for missing values and assign back
11891203 # to the result where appropriate
11901204 if keep_na and mask[lexsort_indexer[i]]:
1205+ grp_na_count = dups
11911206 for j in range (i - dups + 1 , i + 1 ):
11921207 out[lexsort_indexer[j]] = NaN
1193- grp_na_count = dups
11941208 elif tiebreak == TIEBREAK_AVERAGE:
11951209 for j in range (i - dups + 1 , i + 1 ):
11961210 out[lexsort_indexer[j]] = sum_ranks / < float64_t> dups
@@ -1200,48 +1214,51 @@ def rank_1d(
12001214 elif tiebreak == TIEBREAK_MAX:
12011215 for j in range (i - dups + 1 , i + 1 ):
12021216 out[lexsort_indexer[j]] = i - grp_start + 1
1217+
1218+ # With n as the previous rank in the group and m as the number
1219+ # of duplicates in this stretch, if TIEBREAK_FIRST and ascending,
1220+ # then rankings should be n + 1, n + 2 ... n + m
12031221 elif tiebreak == TIEBREAK_FIRST:
12041222 for j in range (i - dups + 1 , i + 1 ):
1205- if ascending:
1206- out[lexsort_indexer[j]] = j + 1 - grp_start
1207- else :
1208- out[lexsort_indexer[j]] = \
1209- (2 * i - j - dups + 2 - grp_start)
1223+ out[lexsort_indexer[j]] = j + 1 - grp_start
1224+
1225+ # If TIEBREAK_FIRST and descending, the ranking should be
1226+ # n + m, n + (m - 1) ... n + 1. This is equivalent to
1227+ # (i - dups + 1) + (i - j + 1) - grp_start
1228+ elif tiebreak == TIEBREAK_FIRST_DESCENDING:
1229+ for j in range (i - dups + 1 , i + 1 ):
1230+ out[lexsort_indexer[j]] = 2 * i - j - dups + 2 - grp_start
12101231 elif tiebreak == TIEBREAK_DENSE:
12111232 for j in range (i - dups + 1 , i + 1 ):
12121233 out[lexsort_indexer[j]] = grp_vals_seen
12131234
1214- # look forward to the next value (using the sorting in
1235+ # Look forward to the next value (using the sorting in
12151236 # lexsort_indexer) if the value does not equal the current
1216- # value then we need to reset the dups and sum_ranks,
1217- # knowing that a new value is coming up. the conditional
1218- # also needs to handle nan equality and the end of iteration
1237+ # value then we need to reset the dups and sum_ranks, knowing
1238+ # that a new value is coming up. The conditional also needs
1239+ # to handle nan equality and the end of iteration
12191240 if next_val_diff or (mask[lexsort_indexer[i]]
12201241 ^ mask[lexsort_indexer[i+ 1 ]]):
12211242 dups = sum_ranks = 0
12221243 grp_vals_seen += 1
1223- grp_tie_count += 1
12241244
12251245 # Similar to the previous conditional, check now if we are
12261246 # moving to a new group. If so, keep track of the index where
12271247 # the new group occurs, so the tiebreaker calculations can
1228- # decrement that from their position. fill in the size of each
1229- # group encountered (used by pct calculations later). also be
1248+ # decrement that from their position. Fill in the size of each
1249+ # group encountered (used by pct calculations later). Also be
12301250 # sure to reset any of the items helping to calculate dups
1231- if at_end or (check_labels and
1232- (labels[lexsort_indexer[i]]
1233- != labels[lexsort_indexer[i+ 1 ]])):
1251+ if group_changed:
12341252 if tiebreak != TIEBREAK_DENSE:
12351253 for j in range (grp_start, i + 1 ):
12361254 grp_sizes[lexsort_indexer[j]] = \
12371255 (i - grp_start + 1 - grp_na_count)
12381256 else :
12391257 for j in range (grp_start, i + 1 ):
12401258 grp_sizes[lexsort_indexer[j]] = \
1241- (grp_tie_count - (grp_na_count > 0 ))
1259+ (grp_vals_seen - 1 - (grp_na_count > 0 ))
12421260 dups = sum_ranks = 0
12431261 grp_na_count = 0
1244- grp_tie_count = 0
12451262 grp_start = i + 1
12461263 grp_vals_seen = 1
12471264
0 commit comments