Skip to content

Commit ace9941

Browse files
simonreisepre-commit-ci[bot]Borda
committed
Float input support for segmentation metrics (#3198)
* Add logit support to segmentation metrics * Improved num_classes logic * Fix for mypy * chlog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka B <[email protected]> (cherry picked from commit 357c349)
1 parent b6c0166 commit ace9941

File tree

12 files changed

+134
-102
lines changed

12 files changed

+134
-102
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2121
-
2222

2323

24+
- Float input support for segmentation metrics ([#3198](https://github.com/Lightning-AI/torchmetrics/pull/3198))
25+
26+
2427
### Deprecated
2528

2629
-

src/torchmetrics/functional/segmentation/dice.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
from torch import Tensor
1818
from typing_extensions import Literal
1919

20-
from torchmetrics.functional.segmentation.utils import _check_mixed_shape, _ignore_background
20+
from torchmetrics.functional.segmentation.utils import _segmentation_inputs_format
2121
from torchmetrics.utilities import rank_zero_warn
22-
from torchmetrics.utilities.checks import _check_same_shape
2322
from torchmetrics.utilities.compute import _safe_divide
2423

2524

@@ -56,25 +55,7 @@ def _dice_score_update(
5655
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
5756
) -> tuple[Tensor, Tensor, Tensor]:
5857
"""Update the state with the current prediction and target."""
59-
if input_format == "mixed":
60-
_check_mixed_shape(preds, target)
61-
else:
62-
_check_same_shape(preds, target)
63-
64-
if input_format == "index":
65-
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
66-
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
67-
elif input_format == "mixed":
68-
if preds.dim() == (target.dim() + 1):
69-
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
70-
elif (preds.dim() + 1) == target.dim():
71-
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
72-
73-
if preds.ndim < 3:
74-
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
75-
76-
if not include_background:
77-
preds, target = _ignore_background(preds, target)
58+
preds, target = _segmentation_inputs_format(preds, target, include_background, num_classes, input_format)
7859

7960
reduce_axis = list(range(2, target.ndim))
8061
intersection = torch.sum(preds * target, dim=reduce_axis)

src/torchmetrics/functional/segmentation/generalized_dice.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
from torch import Tensor
1818
from typing_extensions import Literal
1919

20-
from torchmetrics.functional.segmentation.utils import _check_mixed_shape, _ignore_background
21-
from torchmetrics.utilities.checks import _check_same_shape
20+
from torchmetrics.functional.segmentation.utils import _segmentation_inputs_format
2221
from torchmetrics.utilities.compute import _safe_divide
2322

2423

@@ -55,25 +54,7 @@ def _generalized_dice_update(
5554
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
5655
) -> Tuple[Tensor, Tensor]:
5756
"""Update the state with the current prediction and target."""
58-
if input_format == "mixed":
59-
_check_mixed_shape(preds, target)
60-
else:
61-
_check_same_shape(preds, target)
62-
63-
if input_format == "index":
64-
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
65-
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
66-
elif input_format == "mixed":
67-
if preds.dim() == (target.dim() + 1):
68-
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
69-
elif (preds.dim() + 1) == target.dim():
70-
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
71-
72-
if preds.ndim < 3:
73-
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
74-
75-
if not include_background:
76-
preds, target = _ignore_background(preds, target)
57+
preds, target = _segmentation_inputs_format(preds, target, include_background, num_classes, input_format)
7758

7859
reduce_axis = list(range(2, target.ndim))
7960
intersection = torch.sum(preds * target, dim=reduce_axis)

src/torchmetrics/functional/segmentation/hausdorff_distance.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@
1717
import torch
1818
from torch import Tensor
1919

20-
from torchmetrics.functional.segmentation.utils import (
21-
_check_mixed_shape,
22-
_ignore_background,
23-
edge_surface_distance,
24-
)
25-
from torchmetrics.utilities.checks import _check_same_shape
20+
from torchmetrics.functional.segmentation.utils import _segmentation_inputs_format, edge_surface_distance
2621

2722

2823
def _hausdorff_distance_validate_args(
@@ -93,22 +88,8 @@ def hausdorff_distance(
9388
9489
"""
9590
_hausdorff_distance_validate_args(num_classes, include_background, distance_metric, spacing, directed, input_format)
96-
if input_format == "mixed":
97-
_check_mixed_shape(preds, target)
98-
else:
99-
_check_same_shape(preds, target)
10091

101-
if input_format == "index":
102-
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
103-
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
104-
elif input_format == "mixed":
105-
if preds.dim() == (target.dim() + 1):
106-
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
107-
elif (preds.dim() + 1) == target.dim():
108-
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
109-
110-
if not include_background:
111-
preds, target = _ignore_background(preds, target)
92+
preds, target = _segmentation_inputs_format(preds, target, include_background, num_classes, input_format)
11293

11394
distances = torch.zeros(preds.shape[0], preds.shape[1], device=preds.device)
11495

src/torchmetrics/functional/segmentation/mean_iou.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
from torch import Tensor
1919
from typing_extensions import Literal
2020

21-
from torchmetrics.functional.segmentation.utils import _check_mixed_shape, _ignore_background
22-
from torchmetrics.utilities.checks import _check_same_shape
21+
from torchmetrics.functional.segmentation.utils import _segmentation_inputs_format
2322
from torchmetrics.utilities.compute import _safe_divide
2423

2524

@@ -51,8 +50,8 @@ def _mean_iou_validate_args(
5150
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
5251
) -> None:
5352
"""Validate the arguments of the metric."""
54-
if input_format in ["index", "mixed"] and num_classes is None:
55-
raise ValueError("Argument `num_classes` must be provided when `input_format` is 'index' or 'mixed'.")
53+
if input_format in ["index"] and num_classes is None:
54+
raise ValueError("Argument `num_classes` must be provided when `input_format` is 'index'.")
5655
if num_classes is not None and num_classes <= 0:
5756
raise ValueError(
5857
f"Expected argument `num_classes` must be `None` or a positive integer, but got {num_classes}."
@@ -76,33 +75,8 @@ def _mean_iou_update(
7675
) -> tuple[Tensor, Tensor]:
7776
"""Update the intersection and union counts for the mean IoU computation."""
7877
preds, target = _mean_iou_reshape_args(preds, target, input_format)
79-
if input_format == "mixed":
80-
_check_mixed_shape(preds, target)
81-
else:
82-
_check_same_shape(preds, target)
83-
84-
if input_format == "index":
85-
if num_classes is None:
86-
raise ValueError("Argument `num_classes` must be provided when `input_format='index'`.")
87-
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
88-
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
89-
elif input_format == "one-hot" and num_classes is None:
90-
try:
91-
num_classes = preds.shape[1]
92-
except IndexError as err:
93-
raise IndexError(f"Cannot determine `num_classes` from `preds` tensor: {preds}.") from err
94-
if num_classes == 0:
95-
raise ValueError(f"Expected argument `num_classes` to be a positive integer, but got {num_classes}.")
96-
elif input_format == "mixed":
97-
if num_classes is None:
98-
raise ValueError("Argument `num_classes` must be provided when `input_format='mixed'`.")
99-
if preds.dim() == (target.dim() + 1):
100-
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
101-
elif (preds.dim() + 1) == target.dim():
102-
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
103-
104-
if not include_background:
105-
preds, target = _ignore_background(preds, target)
78+
79+
preds, target = _segmentation_inputs_format(preds, target, include_background, num_classes, input_format)
10680

10781
reduce_axis = list(range(2, preds.ndim))
10882
intersection = torch.sum(preds & target, dim=reduce_axis)
@@ -136,7 +110,8 @@ def mean_iou(
136110
Args:
137111
preds: Predictions from model
138112
target: Ground truth values
139-
num_classes: Number of classes (required when input_format="index", optional when input_format="one-hot")
113+
num_classes: Number of classes
114+
(required when input_format="index", optional when input_format="one-hot" or "mixed")
140115
include_background: Whether to include the background class in the computation
141116
per_class: Whether to compute the IoU for each class separately, else average over all classes
142117
input_format: What kind of input the function receives.

src/torchmetrics/functional/segmentation/utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,69 @@ def _check_mixed_shape(preds: Tensor, target: Tensor) -> None:
4949
)
5050

5151

52+
def _segmentation_inputs_format(
53+
preds: Tensor,
54+
target: Tensor,
55+
include_background: bool,
56+
num_classes: Optional[int] = None,
57+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
58+
) -> tuple[Tensor, Tensor]:
59+
"""Check and format inputs to the one-hot encodings."""
60+
if input_format == "mixed":
61+
_check_mixed_shape(preds, target)
62+
else:
63+
_check_same_shape(preds, target)
64+
65+
if input_format == "index":
66+
if num_classes is None:
67+
raise ValueError("Argument `num_classes` must be provided when `input_format='index'`.")
68+
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
69+
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
70+
elif input_format == "one-hot":
71+
if num_classes is None:
72+
num_classes = _get_num_classes(preds)
73+
preds = _format_logits(preds, num_classes)
74+
target = _format_logits(target, num_classes)
75+
elif input_format == "mixed":
76+
if preds.dim() == (target.dim() + 1):
77+
if num_classes is None:
78+
num_classes = _get_num_classes(preds)
79+
preds = _format_logits(preds, num_classes)
80+
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
81+
elif (preds.dim() + 1) == target.dim():
82+
if num_classes is None:
83+
num_classes = _get_num_classes(target)
84+
target = _format_logits(target, num_classes)
85+
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
86+
87+
if preds.ndim < 3:
88+
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
89+
90+
if not include_background:
91+
preds, target = _ignore_background(preds, target)
92+
93+
return preds, target
94+
95+
96+
def _format_logits(tensor: Tensor, num_classes: int) -> Tensor:
97+
"""Transform logits or probabilities into integer one-hot encodings."""
98+
if torch.is_floating_point(tensor):
99+
tensor = tensor.argmax(dim=1)
100+
tensor = torch.nn.functional.one_hot(tensor, num_classes=num_classes).movedim(-1, 1)
101+
return tensor
102+
103+
104+
def _get_num_classes(tensor: Tensor) -> int:
105+
"""Get num classes from a tensor if it is not set."""
106+
try:
107+
num_classes = tensor.shape[1]
108+
except IndexError as err:
109+
raise IndexError(f"Cannot determine `num_classes` from tensor: {tensor}.") from err
110+
if num_classes == 0:
111+
raise ValueError(f"Expected argument `num_classes` to be a positive integer, but got {num_classes}.")
112+
return num_classes
113+
114+
52115
def check_if_binarized(x: Tensor) -> None:
53116
"""Check if tensor is binarized.
54117

src/torchmetrics/segmentation/mean_iou.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ class MeanIoU(Metric):
5353
set to ``False``, the output will be a scalar tensor.
5454
5555
Args:
56-
num_classes: The number of classes in the segmentation problem. Required when input_format="index" or "mixed",
57-
optional when input_format="one-hot".
56+
num_classes: The number of classes in the segmentation problem. Required when input_format="index",
57+
optional when input_format="one-hot" or "mixed".
5858
include_background: Whether to include the background class in the computation
5959
per_class: Whether to compute the IoU for each class separately. If set to ``False``, the metric will
6060
compute the mean IoU over all classes.
@@ -67,7 +67,7 @@ class MeanIoU(Metric):
6767
ValueError:
6868
If ``num_classes`` is not ``None`` or a positive integer
6969
ValueError:
70-
If ``num_classes`` is not provided when ``input_format`` is ``"index"`` or ``"mixed"``
70+
If ``num_classes`` is not provided when ``input_format`` is ``"index"``
7171
ValueError:
7272
If ``include_background`` is not a boolean
7373
ValueError:
@@ -132,7 +132,20 @@ def update(self, preds: Tensor, target: Tensor) -> None:
132132
"""Update the state with the new data."""
133133
if not self._is_initialized:
134134
try:
135-
self.num_classes = preds.shape[1]
135+
if self.input_format == "one-hot":
136+
self.num_classes = preds.shape[1]
137+
elif self.input_format == "mixed":
138+
if preds.dim() == (target.dim() + 1):
139+
self.num_classes = preds.shape[1]
140+
elif (preds.dim() + 1) == target.dim():
141+
self.num_classes = target.shape[1]
142+
else:
143+
raise ValueError(
144+
"Predictions and targets are expected to have the same shape,",
145+
f"got {preds.shape} and {target.shape}.",
146+
)
147+
else:
148+
raise ValueError("Argument `num_classes` must be provided when `input_format` is 'index'.")
136149
except IndexError as err:
137150
raise IndexError(f"Cannot determine `num_classes` from `preds` tensor: {preds}.") from err
138151

tests/unittests/segmentation/inputs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,8 @@
4848
preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
4949
target=to_one_hot(torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32))),
5050
)
51+
52+
_mixed_logits_input = _Input(
53+
preds=(torch.rand((NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)) * 12 - 6),
54+
target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
55+
)

tests/unittests/segmentation/test_dice.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
_index_input_2,
2929
_mixed_input_1,
3030
_mixed_input_2,
31+
_mixed_logits_input,
3132
_one_hot_input_1,
3233
_one_hot_input_2,
3334
)
@@ -76,6 +77,7 @@ def _reference_dice_score(
7677
(_index_input_2.preds, _index_input_2.target, "index"),
7778
(_mixed_input_1.preds, _mixed_input_1.target, "mixed"),
7879
(_mixed_input_2.preds, _mixed_input_2.target, "mixed"),
80+
(_mixed_logits_input.preds, _mixed_logits_input.target, "mixed"),
7981
],
8082
)
8183
@pytest.mark.parametrize("include_background", [True, False])

tests/unittests/segmentation/test_generalized_dice_score.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_index_input_2,
3030
_mixed_input_1,
3131
_mixed_input_2,
32+
_mixed_logits_input,
3233
_one_hot_input_1,
3334
_one_hot_input_2,
3435
)
@@ -49,8 +50,14 @@ def _reference_generalized_dice(
4950
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
5051
elif input_format == "mixed":
5152
if preds.dim() == (target.dim() + 1):
53+
if torch.is_floating_point(preds):
54+
preds = preds.argmax(dim=1)
55+
preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
5256
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
5357
elif (preds.dim() + 1) == target.dim():
58+
if torch.is_floating_point(target):
59+
target = target.argmax(dim=1)
60+
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
5461
preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
5562
monai_extra_arg = {"sum_over_classes": True} if RequirementCache("monai>=1.4.0") else {}
5663
val = compute_generalized_dice(preds, target, include_background=include_background, **monai_extra_arg)
@@ -68,6 +75,7 @@ def _reference_generalized_dice(
6875
(_index_input_2.preds, _index_input_2.target, "index"),
6976
(_mixed_input_1.preds, _mixed_input_1.target, "mixed"),
7077
(_mixed_input_2.preds, _mixed_input_2.target, "mixed"),
78+
(_mixed_logits_input.preds, _mixed_logits_input.target, "mixed"),
7179
],
7280
)
7381
@pytest.mark.parametrize("include_background", [True, False])

0 commit comments

Comments
 (0)