Skip to content

Commit 963c17b

Browse files
authored
formatting 6/n: metrics (#5722)
* yapf metrics * op
1 parent 069ae27 commit 963c17b

31 files changed

+375
-356
lines changed

.yapfignore

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ pytorch_lightning/core/*
2020
# TODO
2121
pytorch_lightning/loggers/*
2222

23-
24-
# TODO
25-
pytorch_lightning/metrics/*
26-
2723
# TODO
2824
pytorch_lightning/plugins/legacy/*
2925

pytorch_lightning/metrics/classification/auc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class AUC(Metric):
4242
Callback that performs the allgather operation on the metric state. When ``None``, DDP
4343
will be used to perform the allgather
4444
"""
45+
4546
def __init__(
4647
self,
4748
reorder: bool = False,

pytorch_lightning/metrics/classification/auroc.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class AUROC(Metric):
8686
tensor(0.7778)
8787
8888
"""
89+
8990
def __init__(
9091
self,
9192
num_classes: Optional[int] = None,
@@ -111,8 +112,9 @@ def __init__(
111112

112113
allowed_average = (None, 'macro', 'weighted')
113114
if self.average not in allowed_average:
114-
raise ValueError('Argument `average` expected to be one of the following:'
115-
f' {allowed_average} but got {average}')
115+
raise ValueError(
116+
f'Argument `average` expected to be one of the following: {allowed_average} but got {average}'
117+
)
116118

117119
if self.max_fpr is not None:
118120
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
@@ -146,8 +148,10 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
146148
self.target.append(target)
147149

148150
if self.mode is not None and self.mode != mode:
149-
raise ValueError('The mode of data (binary, multi-label, multi-class) should be constant, but changed'
150-
f' between batches from {self.mode} to {mode}')
151+
raise ValueError(
152+
'The mode of data (binary, multi-label, multi-class) should be constant, but changed'
153+
f' between batches from {self.mode} to {mode}'
154+
)
151155
self.mode = mode
152156

153157
def compute(self) -> torch.Tensor:
@@ -163,5 +167,5 @@ def compute(self) -> torch.Tensor:
163167
self.num_classes,
164168
self.pos_label,
165169
self.average,
166-
self.max_fpr
170+
self.max_fpr,
167171
)

pytorch_lightning/metrics/classification/average_precision.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class AveragePrecision(Metric):
6868
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
6969
7070
"""
71+
7172
def __init__(
7273
self,
7374
num_classes: Optional[int] = None,
@@ -102,10 +103,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
102103
target: Ground truth values
103104
"""
104105
preds, target, num_classes, pos_label = _average_precision_update(
105-
preds,
106-
target,
107-
self.num_classes,
108-
self.pos_label
106+
preds, target, self.num_classes, self.pos_label
109107
)
110108
self.preds.append(preds)
111109
self.target.append(target)

pytorch_lightning/metrics/classification/confusion_matrix.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class ConfusionMatrix(Metric):
7070
[1., 1.]])
7171
7272
"""
73+
7374
def __init__(
7475
self,
7576
num_classes: int,

pytorch_lightning/metrics/classification/f_beta.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ def __init__(
8787
process_group: Optional[Any] = None,
8888
):
8989
super().__init__(
90-
compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
90+
compute_on_step=compute_on_step,
91+
dist_sync_on_step=dist_sync_on_step,
92+
process_group=process_group,
9193
)
9294

9395
self.num_classes = num_classes
@@ -98,8 +100,10 @@ def __init__(
98100

99101
allowed_average = ("micro", "macro", "weighted", None)
100102
if self.average not in allowed_average:
101-
raise ValueError('Argument `average` expected to be one of the following:'
102-
f' {allowed_average} but got {self.average}')
103+
raise ValueError(
104+
'Argument `average` expected to be one of the following:'
105+
f' {allowed_average} but got {self.average}'
106+
)
103107

104108
self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
105109
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
@@ -125,8 +129,9 @@ def compute(self) -> torch.Tensor:
125129
"""
126130
Computes fbeta over state.
127131
"""
128-
return _fbeta_compute(self.true_positives, self.predicted_positives,
129-
self.actual_positives, self.beta, self.average)
132+
return _fbeta_compute(
133+
self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average
134+
)
130135

131136

132137
class F1(FBeta):

pytorch_lightning/metrics/classification/iou.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,15 @@ class IoU(ConfusionMatrix):
7878
"""
7979

8080
def __init__(
81-
self,
82-
num_classes: int,
83-
ignore_index: Optional[int] = None,
84-
absent_score: float = 0.0,
85-
threshold: float = 0.5,
86-
reduction: str = 'elementwise_mean',
87-
compute_on_step: bool = True,
88-
dist_sync_on_step: bool = False,
89-
process_group: Optional[Any] = None,
81+
self,
82+
num_classes: int,
83+
ignore_index: Optional[int] = None,
84+
absent_score: float = 0.0,
85+
threshold: float = 0.5,
86+
reduction: str = 'elementwise_mean',
87+
compute_on_step: bool = True,
88+
dist_sync_on_step: bool = False,
89+
process_group: Optional[Any] = None,
9090
):
9191
super().__init__(
9292
num_classes=num_classes,

pytorch_lightning/metrics/classification/precision_recall_curve.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class PrecisionRecallCurve(Metric):
8282
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
8383
8484
"""
85+
8586
def __init__(
8687
self,
8788
num_classes: Optional[int] = None,
@@ -116,18 +117,17 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
116117
target: Ground truth values
117118
"""
118119
preds, target, num_classes, pos_label = _precision_recall_curve_update(
119-
preds,
120-
target,
121-
self.num_classes,
122-
self.pos_label
120+
preds, target, self.num_classes, self.pos_label
123121
)
124122
self.preds.append(preds)
125123
self.target.append(target)
126124
self.num_classes = num_classes
127125
self.pos_label = pos_label
128126

129-
def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
130-
Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
127+
def compute(
128+
self
129+
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
130+
List[torch.Tensor]]]:
131131
"""
132132
Compute the precision-recall curve
133133

pytorch_lightning/metrics/classification/roc.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class ROC(Metric):
8181
tensor([1.7500, 0.7500, 0.0500])]
8282
8383
"""
84+
8485
def __init__(
8586
self,
8687
num_classes: Optional[int] = None,
@@ -114,19 +115,16 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
114115
preds: Predictions from model
115116
target: Ground truth values
116117
"""
117-
preds, target, num_classes, pos_label = _roc_update(
118-
preds,
119-
target,
120-
self.num_classes,
121-
self.pos_label
122-
)
118+
preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label)
123119
self.preds.append(preds)
124120
self.target.append(target)
125121
self.num_classes = num_classes
126122
self.pos_label = pos_label
127123

128-
def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
129-
Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
124+
def compute(
125+
self
126+
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
127+
List[torch.Tensor]]]:
130128
"""
131129
Compute the receiver operating characteristic
132130

pytorch_lightning/metrics/classification/stat_scores.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def __init__(
165165
if reduce == "micro":
166166
zeros_shape = []
167167
elif reduce == "macro":
168-
zeros_shape = (num_classes,)
168+
zeros_shape = (num_classes, )
169169
default, reduce_fn = lambda: torch.zeros(zeros_shape, dtype=torch.long), "sum"
170170
else:
171171
default, reduce_fn = lambda: [], None

0 commit comments

Comments
 (0)