1818from torch import Tensor
1919from typing_extensions import Literal
2020
21- from torchmetrics .functional .segmentation .utils import _check_mixed_shape , _ignore_background
22- from torchmetrics .utilities .checks import _check_same_shape
21+ from torchmetrics .functional .segmentation .utils import _segmentation_inputs_format
2322from torchmetrics .utilities .compute import _safe_divide
2423
2524
@@ -51,8 +50,8 @@ def _mean_iou_validate_args(
5150 input_format : Literal ["one-hot" , "index" , "mixed" ] = "one-hot" ,
5251) -> None :
5352 """Validate the arguments of the metric."""
54- if input_format in ["index" , "mixed" ] and num_classes is None :
55- raise ValueError ("Argument `num_classes` must be provided when `input_format` is 'index' or 'mixed' ." )
53+ if input_format in ["index" ] and num_classes is None :
54+ raise ValueError ("Argument `num_classes` must be provided when `input_format` is 'index'." )
5655 if num_classes is not None and num_classes <= 0 :
5756 raise ValueError (
5857 f"Expected argument `num_classes` must be `None` or a positive integer, but got { num_classes } ."
@@ -76,33 +75,8 @@ def _mean_iou_update(
7675) -> tuple [Tensor , Tensor ]:
7776 """Update the intersection and union counts for the mean IoU computation."""
7877 preds , target = _mean_iou_reshape_args (preds , target , input_format )
79- if input_format == "mixed" :
80- _check_mixed_shape (preds , target )
81- else :
82- _check_same_shape (preds , target )
83-
84- if input_format == "index" :
85- if num_classes is None :
86- raise ValueError ("Argument `num_classes` must be provided when `input_format='index'`." )
87- preds = torch .nn .functional .one_hot (preds , num_classes = num_classes ).movedim (- 1 , 1 )
88- target = torch .nn .functional .one_hot (target , num_classes = num_classes ).movedim (- 1 , 1 )
89- elif input_format == "one-hot" and num_classes is None :
90- try :
91- num_classes = preds .shape [1 ]
92- except IndexError as err :
93- raise IndexError (f"Cannot determine `num_classes` from `preds` tensor: { preds } ." ) from err
94- if num_classes == 0 :
95- raise ValueError (f"Expected argument `num_classes` to be a positive integer, but got { num_classes } ." )
96- elif input_format == "mixed" :
97- if num_classes is None :
98- raise ValueError ("Argument `num_classes` must be provided when `input_format='mixed'`." )
99- if preds .dim () == (target .dim () + 1 ):
100- target = torch .nn .functional .one_hot (target , num_classes = num_classes ).movedim (- 1 , 1 )
101- elif (preds .dim () + 1 ) == target .dim ():
102- preds = torch .nn .functional .one_hot (preds , num_classes = num_classes ).movedim (- 1 , 1 )
103-
104- if not include_background :
105- preds , target = _ignore_background (preds , target )
78+
79+ preds , target = _segmentation_inputs_format (preds , target , include_background , num_classes , input_format )
10680
10781 reduce_axis = list (range (2 , preds .ndim ))
10882 intersection = torch .sum (preds & target , dim = reduce_axis )
@@ -136,7 +110,8 @@ def mean_iou(
136110 Args:
137111 preds: Predictions from model
138112 target: Ground truth values
139- num_classes: Number of classes (required when input_format="index", optional when input_format="one-hot")
113+ num_classes: Number of classes
114+ (required when input_format="index", optional when input_format="one-hot" or "mixed")
140115 include_background: Whether to include the background class in the computation
141116 per_class: Whether to compute the IoU for each class separately, else average over all classes
142117 input_format: What kind of input the function receives.
0 commit comments