Skip to content

Commit 3ee1b55

Browse files
deng-cySkafteNickiBorda
authored
Reformat iou [func] and add IoU class (#4704)
* added Iou * Create iou.py * Update iou.py * Update iou.py * Update CHANGELOG.md * Update metrics.rst * Update iou.py * Update iou.py * Update __init__.py * Update iou.py * Update iou.py * Update classification.py * Update classification.py * Update classification.py * Update __init__.py * Update __init__.py * Update iou.py * Update classification.py * Update metrics.rst * Update CHANGELOG.md * Update CHANGELOG.md * add iou * add test * add test * removed iou * add iou * add iou test * add float * reformat test_iou * removed test_iou * updated format * updated format * Update CHANGELOG.md * updated format * Update metrics.rst * Apply suggestions from code review merge suggestions Co-authored-by: Nicki Skafte <[email protected]> * added equations * reformat init * change format * change format * deprecate iou and test for this * fix changelog * delete iou test in test_classification * format change * format change * format * format * format * delete white space * delete white space * fix tests * Apply suggestions from code review Co-authored-by: Jirka Borovec <[email protected]> * Apply suggestions from code review Co-authored-by: Jirka Borovec <[email protected]> * better deprecation * fix docs * Apply suggestions from code review * fix todo Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 06668c0 commit 3ee1b55

File tree

11 files changed

+478
-157
lines changed

11 files changed

+478
-157
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3939
- Added `.clone()` method to metrics ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318))
4040

4141

42+
- Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))
43+
44+
4245
### Changed
4346

4447
- Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218))
@@ -47,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4750
- `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
4851

4952

53+
- Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))
54+
55+
5056
### Deprecated
5157

5258
- `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

docs/source/metrics.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,12 @@ FBeta
465465

466466
.. autoclass:: pytorch_lightning.metrics.classification.FBeta
467467
:noindex:
468+
469+
IoU
470+
~~~
471+
472+
.. autoclass:: pytorch_lightning.metrics.classification.IoU
473+
:noindex:
468474

469475
Hamming Distance
470476
~~~~~~~~~~~~~~~~
@@ -577,7 +583,7 @@ hamming_distance [func]
577583
iou [func]
578584
~~~~~~~~~~
579585

580-
.. autofunction:: pytorch_lightning.metrics.functional.classification.iou
586+
.. autofunction:: pytorch_lightning.metrics.functional.iou
581587
:noindex:
582588

583589

pytorch_lightning/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytorch_lightning.metrics.classification import ( # noqa: F401
1717
Accuracy,
1818
HammingDistance,
19+
IoU,
1920
Precision,
2021
Recall,
2122
ConfusionMatrix,

pytorch_lightning/metrics/classification/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
2121
from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401
2222
from pytorch_lightning.metrics.classification.stat_scores import StatScores # noqa: F401
23+
from pytorch_lightning.metrics.classification.iou import IoU # noqa: F401
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, Optional
15+
16+
import torch
17+
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
18+
from pytorch_lightning.metrics.functional.iou import _iou_from_confmat
19+
20+
21+
class IoU(ConfusionMatrix):
22+
r"""
23+
Computes `Intersection over union, or Jaccard index calculation <https://en.wikipedia.org/wiki/Jaccard_index>`_:
24+
25+
.. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
26+
27+
Where: :math:`A` and :math:`B` are both tensors of the same size, containing integer class values.
28+
They may be subject to conversion from input data (see description below). Note that it is different from box IoU.
29+
30+
Works with binary, multiclass and multi-label data.
31+
Accepts logits from a model output or integer class values in prediction.
32+
Works with multi-dimensional preds and target.
33+
34+
Forward accepts
35+
36+
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
37+
- ``target`` (long tensor): ``(N, ...)``
38+
39+
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
40+
This is the case for binary and multi-label logits.
41+
42+
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
43+
44+
Args:
45+
num_classes: Number of classes in the dataset.
46+
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
47+
to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
48+
range [0, num_classes-1]. By default, no index is ignored, and all classes are used.
49+
absent_score: score to use for an individual class, if no instances of the class index were present in
50+
`pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes,
51+
[0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`.
52+
threshold:
53+
Threshold value for binary or multi-label logits.
54+
reduction: a method to reduce metric score over labels.
55+
56+
- ``'elementwise_mean'``: takes the mean (default)
57+
- ``'sum'``: takes the sum
58+
- ``'none'``: no reduction will be applied
59+
60+
compute_on_step:
61+
Forward only calls ``update()`` and return None if this is set to False.
62+
dist_sync_on_step:
63+
Synchronize metric state across processes at each ``forward()``
64+
before returning the value at the step.
65+
process_group:
66+
Specify the process group on which synchronization is called. default: None (which selects the entire world)
67+
68+
Example:
69+
>>> from pytorch_lightning.metrics import IoU
70+
>>> target = torch.randint(0, 2, (10, 25, 25))
71+
>>> pred = torch.tensor(target)
72+
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
73+
>>> iou = IoU(num_classes=2)
74+
>>> iou(pred, target)
75+
tensor(0.9660)
76+
77+
"""
78+
79+
def __init__(
80+
self,
81+
num_classes: int,
82+
ignore_index: Optional[int] = None,
83+
absent_score: float = 0.0,
84+
threshold: float = 0.5,
85+
reduction: str = 'elementwise_mean',
86+
compute_on_step: bool = True,
87+
dist_sync_on_step: bool = False,
88+
process_group: Optional[Any] = None,
89+
):
90+
super().__init__(
91+
num_classes=num_classes,
92+
normalize=None,
93+
threshold=threshold,
94+
compute_on_step=compute_on_step,
95+
dist_sync_on_step=dist_sync_on_step,
96+
process_group=process_group,
97+
)
98+
self.reduction = reduction
99+
self.ignore_index = ignore_index
100+
self.absent_score = absent_score
101+
102+
def compute(self) -> torch.Tensor:
103+
"""
104+
Computes intersection over union (IoU)
105+
"""
106+
return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction)

pytorch_lightning/metrics/functional/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
auroc,
1818
dice_score,
1919
get_num_classes,
20-
iou,
2120
multiclass_auroc,
2221
precision,
2322
precision_recall,
@@ -32,6 +31,8 @@
3231
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
3332
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
3433
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
34+
from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401
35+
from pytorch_lightning.metrics.functional.iou import iou # noqa: F401
3536
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401
3637
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401
3738
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401
@@ -43,4 +44,3 @@
4344
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401
4445
from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401
4546
from pytorch_lightning.metrics.functional.stat_scores import stat_scores # noqa: F401
46-
from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401

pytorch_lightning/metrics/functional/classification.py

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from distutils.version import LooseVersion
1415
from functools import wraps
1516
from typing import Callable, Optional, Sequence, Tuple
1617

1718
import torch
18-
from distutils.version import LooseVersion
19-
2019
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
20+
from pytorch_lightning.metrics.functional.iou import iou as __iou
2121
from pytorch_lightning.metrics.functional.precision_recall_curve import (
2222
_binary_clf_curve,
2323
precision_recall_curve as __prc
@@ -84,7 +84,7 @@ def get_num_classes(
8484
" `from pytorch_lightning.metrics.utils import get_num_classes`."
8585
" It will be removed in v1.3.0", DeprecationWarning
8686
)
87-
return __gnc(pred,target, num_classes)
87+
return __gnc(pred, target, num_classes)
8888

8989

9090
def stat_scores(
@@ -162,8 +162,8 @@ def stat_scores_multiple_classes(
162162
raise ValueError("reduction type %s not supported" % reduction)
163163

164164
if reduction == 'none':
165-
pred = pred.view((-1, )).long()
166-
target = target.view((-1, )).long()
165+
pred = pred.view((-1,)).long()
166+
target = target.view((-1,)).long()
167167

168168
tps = torch.zeros((num_classes + 1,), device=pred.device)
169169
fps = torch.zeros((num_classes + 1,), device=pred.device)
@@ -687,6 +687,7 @@ def dice_score(
687687
return reduce(scores, reduction=reduction)
688688

689689

690+
# todo: remove in 1.4
690691
def iou(
691692
pred: torch.Tensor,
692693
target: torch.Tensor,
@@ -698,6 +699,10 @@ def iou(
698699
"""
699700
Intersection over union, or Jaccard index calculation.
700701
702+
.. warning :: Deprecated in favor of
703+
:func:`~pytorch_lightning.metrics.functional.iou.iou`. Will be removed in
704+
v1.4.0.
705+
701706
Args:
702707
pred: Tensor containing integer predictions, with shape [N, d1, d2, ...]
703708
target: Tensor containing integer targets, with shape [N, d1, d2, ...]
@@ -729,48 +734,20 @@ def iou(
729734
tensor(0.9660)
730735
731736
"""
732-
if pred.size() != target.size():
733-
raise ValueError(f"'pred' shape ({pred.size()}) must equal 'target' shape ({target.size()})")
734-
735-
if not torch.allclose(pred.float(), pred.int().float()):
736-
raise ValueError("'pred' must contain integer targets.")
737-
738-
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
739-
740-
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
741-
742-
scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32)
743-
744-
for class_idx in range(num_classes):
745-
if class_idx == ignore_index:
746-
continue
747-
748-
tp = tps[class_idx]
749-
fp = fps[class_idx]
750-
fn = fns[class_idx]
751-
sup = sups[class_idx]
752-
753-
# If this class is absent in the target (no support) AND absent in the pred (no true or false
754-
# positives), then use the absent_score for this class.
755-
if sup + tp + fp == 0:
756-
scores[class_idx] = absent_score
757-
continue
758-
759-
denom = tp + fp + fn
760-
# Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above,
761-
# which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). Since all vars are non-negative, we
762-
# can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class.
763-
score = tp.to(torch.float) / denom
764-
scores[class_idx] = score
765-
766-
# Remove the ignored class index from the scores.
767-
if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
768-
scores = torch.cat([
769-
scores[:ignore_index],
770-
scores[ignore_index + 1:],
771-
])
772-
773-
return reduce(scores, reduction=reduction)
737+
rank_zero_warn(
738+
"This `iou` was deprecated in v1.2.0 in favor of"
739+
" `from pytorch_lightning.metrics.functional.iou import iou`."
740+
" It will be removed in v1.4.0", DeprecationWarning
741+
)
742+
return __iou(
743+
pred=pred,
744+
target=target,
745+
ignore_index=ignore_index,
746+
absent_score=absent_score,
747+
threshold=0.5,
748+
num_classes=num_classes,
749+
reduction=reduction
750+
)
774751

775752

776753
# todo: remove in 1.3

0 commit comments

Comments
 (0)