Skip to content

Commit 5d7388d

Browse files
frankiercarmocca
andauthored
Fix when _stable_1d_sort to work when n >= N (#6177)
* Fix when _stable_1d_sort to work when n >= N * Apply suggestions Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 59acf57 commit 5d7388d

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

pytorch_lightning/metrics/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def _stable_1d_sort(x: torch, N: int = 2049):
285285
n = x.numel()
286286
if N - n > 0:
287287
x_max = x.max()
288-
x_pad = torch.cat([x, (x_max + 1) * torch.ones(2049 - n, dtype=x.dtype, device=x.device)], 0)
289-
x_sort = x_pad.sort()
290-
return x_sort.values[:n], x_sort.indices[:n]
288+
x = torch.cat([x, (x_max + 1) * torch.ones(N - n, dtype=x.dtype, device=x.device)], 0)
289+
x_sort = x.sort()
290+
i = min(N, n)
291+
return x_sort.values[:i], x_sort.indices[:i]

tests/metrics/classification/test_auc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@ def test_auc_functional(self, x, y):
6161
])
6262
def test_auc(x, y, expected):
6363
# Test Area Under Curve (AUC) computation
64-
assert auc(torch.tensor(x), torch.tensor(y)) == expected
64+
assert auc(torch.tensor(x), torch.tensor(y), reorder=True) == expected

0 commit comments

Comments
 (0)