4 changes: 0 additions & 4 deletions .yapfignore
@@ -20,10 +20,6 @@ pytorch_lightning/core/*
 # TODO
 pytorch_lightning/loggers/*
 
-
-# TODO
-pytorch_lightning/metrics/*
-
 # TODO
 pytorch_lightning/plugins/legacy/*
 
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/classification/auc.py
@@ -42,6 +42,7 @@ class AUC(Metric):
             Callback that performs the allgather operation on the metric state. When ``None``, DDP
             will be used to perform the allgather
     """
+
     def __init__(
         self,
         reorder: bool = False,
14 changes: 9 additions & 5 deletions pytorch_lightning/metrics/classification/auroc.py
@@ -86,6 +86,7 @@ class AUROC(Metric):
         tensor(0.7778)
 
     """
+
     def __init__(
         self,
         num_classes: Optional[int] = None,
@@ -111,8 +112,9 @@ def __init__(
 
         allowed_average = (None, 'macro', 'weighted')
         if self.average not in allowed_average:
-            raise ValueError('Argument `average` expected to be one of the following:'
-                             f' {allowed_average} but got {average}')
+            raise ValueError(
+                f'Argument `average` expected to be one of the following: {allowed_average} but got {average}'
+            )
 
         if self.max_fpr is not None:
             if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
@@ -146,8 +148,10 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
         self.target.append(target)
 
         if self.mode is not None and self.mode != mode:
-            raise ValueError('The mode of data (binary, multi-label, multi-class) should be constant, but changed'
-                             f' between batches from {self.mode} to {mode}')
+            raise ValueError(
+                'The mode of data (binary, multi-label, multi-class) should be constant, but changed'
+                f' between batches from {self.mode} to {mode}'
+            )
         self.mode = mode
 
     def compute(self) -> torch.Tensor:
@@ -163,5 +167,5 @@ def compute(self) -> torch.Tensor:
             self.num_classes,
             self.pos_label,
             self.average,
-            self.max_fpr
+            self.max_fpr,
         )
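As a reviewer's aside: a minimal usage sketch of the two validation paths reworded above. The import path and printed values are assumptions based on the 1.1-era `pytorch_lightning.metrics` API, not part of this diff.

```python
import torch
from pytorch_lightning.metrics import AUROC

# Binary case: 1d probabilities for the positive class plus integer targets.
auroc = AUROC(pos_label=1)
preds = torch.tensor([0.10, 0.80, 0.40, 0.90])
target = torch.tensor([0, 1, 0, 1])
print(auroc(preds, target))  # tensor(1.) -- every positive outranks every negative

# An unsupported `average` fails fast in __init__ with the reworded message.
try:
    AUROC(num_classes=3, average='micro')
except ValueError as err:
    print(err)  # Argument `average` expected to be one of the following: ...
```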
6 changes: 2 additions & 4 deletions pytorch_lightning/metrics/classification/average_precision.py
@@ -68,6 +68,7 @@ class AveragePrecision(Metric):
         [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
 
     """
+
     def __init__(
         self,
         num_classes: Optional[int] = None,
@@ -102,10 +103,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
             target: Ground truth values
         """
         preds, target, num_classes, pos_label = _average_precision_update(
-            preds,
-            target,
-            self.num_classes,
-            self.pos_label
+            preds, target, self.num_classes, self.pos_label
         )
         self.preds.append(preds)
         self.target.append(target)
pytorch_lightning/metrics/classification/confusion_matrix.py
@@ -70,6 +70,7 @@ class ConfusionMatrix(Metric):
         [1., 1.]])
 
     """
+
     def __init__(
         self,
         num_classes: int,
15 changes: 10 additions & 5 deletions pytorch_lightning/metrics/classification/f_beta.py
@@ -87,7 +87,9 @@ def __init__(
         process_group: Optional[Any] = None,
     ):
         super().__init__(
-            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
         )
 
         self.num_classes = num_classes
@@ -98,8 +100,10 @@ def __init__(
 
         allowed_average = ("micro", "macro", "weighted", None)
         if self.average not in allowed_average:
-            raise ValueError('Argument `average` expected to be one of the following:'
-                             f' {allowed_average} but got {self.average}')
+            raise ValueError(
+                'Argument `average` expected to be one of the following:'
+                f' {allowed_average} but got {self.average}'
+            )
 
         self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
         self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
@@ -125,8 +129,9 @@ def compute(self) -> torch.Tensor:
         """
         Computes fbeta over state.
         """
-        return _fbeta_compute(self.true_positives, self.predicted_positives,
-                              self.actual_positives, self.beta, self.average)
+        return _fbeta_compute(
+            self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average
+        )
 
 
 class F1(FBeta):
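A hedged sketch of the accumulate-then-reduce flow behind the reformatted `_fbeta_compute` call: `update()` adds to the three state tensors batch by batch, and `compute()` reduces them once. Values and import path are illustrative assumptions.

```python
import torch
from pytorch_lightning.metrics import FBeta

# beta < 1 weights precision more than recall; 'macro' averages per-class scores.
fbeta = FBeta(num_classes=3, beta=0.5, average='macro')

batches = [
    (torch.tensor([0, 2, 1]), torch.tensor([0, 1, 1])),
    (torch.tensor([2, 2, 1]), torch.tensor([2, 0, 1])),
]
for preds, target in batches:
    fbeta.update(preds, target)  # accumulates true/predicted/actual positives

print(fbeta.compute())  # one tensor reduced over both batches
```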
18 changes: 9 additions & 9 deletions pytorch_lightning/metrics/classification/iou.py
@@ -78,15 +78,15 @@ class IoU(ConfusionMatrix):
     """
 
     def __init__(
-            self,
-            num_classes: int,
-            ignore_index: Optional[int] = None,
-            absent_score: float = 0.0,
-            threshold: float = 0.5,
-            reduction: str = 'elementwise_mean',
-            compute_on_step: bool = True,
-            dist_sync_on_step: bool = False,
-            process_group: Optional[Any] = None,
+        self,
+        num_classes: int,
+        ignore_index: Optional[int] = None,
+        absent_score: float = 0.0,
+        threshold: float = 0.5,
+        reduction: str = 'elementwise_mean',
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
     ):
         super().__init__(
             num_classes=num_classes,
pytorch_lightning/metrics/classification/precision_recall_curve.py
@@ -82,6 +82,7 @@ class PrecisionRecallCurve(Metric):
         [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
 
     """
+
     def __init__(
         self,
         num_classes: Optional[int] = None,
@@ -116,18 +117,17 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
             target: Ground truth values
         """
         preds, target, num_classes, pos_label = _precision_recall_curve_update(
-            preds,
-            target,
-            self.num_classes,
-            self.pos_label
+            preds, target, self.num_classes, self.pos_label
         )
         self.preds.append(preds)
         self.target.append(target)
         self.num_classes = num_classes
         self.pos_label = pos_label
 
-    def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-                               Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
+    def compute(
+        self
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
+                                                                      List[torch.Tensor]]]:
         """
         Compute the precision-recall curve
 
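The widened `compute()` signature is easier to parse with its two return shapes in mind; a sketch (1.1-era API assumed) of what the `Union` covers: binary inputs yield one tensor triple, multiclass inputs yield three parallel per-class lists.

```python
import torch
from pytorch_lightning.metrics import PrecisionRecallCurve

# Binary: one (precision, recall, thresholds) triple of plain tensors.
pr_binary = PrecisionRecallCurve(pos_label=1)
preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])
precision, recall, thresholds = pr_binary(preds, target)

# Multiclass: three lists, one tensor per class.
pr_multi = PrecisionRecallCurve(num_classes=3)
preds_mc = torch.rand(8, 3).softmax(dim=-1)
target_mc = torch.randint(3, (8, ))
precisions, recalls, thresholds_mc = pr_multi(preds_mc, target_mc)
assert isinstance(precisions, list) and len(precisions) == 3
```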
14 changes: 6 additions & 8 deletions pytorch_lightning/metrics/classification/roc.py
@@ -81,6 +81,7 @@ class ROC(Metric):
         tensor([1.7500, 0.7500, 0.0500])]
 
     """
+
     def __init__(
         self,
         num_classes: Optional[int] = None,
@@ -114,19 +115,16 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
             preds: Predictions from model
             target: Ground truth values
         """
-        preds, target, num_classes, pos_label = _roc_update(
-            preds,
-            target,
-            self.num_classes,
-            self.pos_label
-        )
+        preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label)
         self.preds.append(preds)
         self.target.append(target)
         self.num_classes = num_classes
         self.pos_label = pos_label
 
-    def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-                               Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
+    def compute(
+        self
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
+                                                                      List[torch.Tensor]]]:
         """
         Compute the receiver operating characteristic
 
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/stat_scores.py
@@ -165,7 +165,7 @@ def __init__(
             if reduce == "micro":
                 zeros_shape = []
             elif reduce == "macro":
-                zeros_shape = (num_classes,)
+                zeros_shape = (num_classes, )
             default, reduce_fn = lambda: torch.zeros(zeros_shape, dtype=torch.long), "sum"
         else:
             default, reduce_fn = lambda: [], None
12 changes: 8 additions & 4 deletions pytorch_lightning/metrics/functional/auc.py
@@ -20,11 +20,15 @@
 
 def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if x.ndim > 1 or y.ndim > 1:
-        raise ValueError(f'Expected both `x` and `y` tensor to be 1d, but got'
-                         f' tensors with dimention {x.ndim} and {y.ndim}')
+        raise ValueError(
+            f'Expected both `x` and `y` tensor to be 1d, but got'
+            f' tensors with dimention {x.ndim} and {y.ndim}'
+        )
     if x.numel() != y.numel():
-        raise ValueError(f'Expected the same number of elements in `x` and `y`'
-                         f' tensor but received {x.numel()} and {y.numel()}')
+        raise ValueError(
+            f'Expected the same number of elements in `x` and `y`'
+            f' tensor but received {x.numel()} and {y.numel()}'
+        )
     return x, y
 
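For reference, the two checks above in action; a small sketch assuming the 1.1-era functional import path (the printed area is plain trapezoidal arithmetic):

```python
import torch
from pytorch_lightning.metrics.functional import auc

x = torch.tensor([0.0, 0.5, 1.0])  # must be 1d and the same length as y
y = torch.tensor([0.0, 0.75, 1.0])
print(auc(x, y))  # tensor(0.6250), trapezoidal area under the (x, y) curve

try:
    auc(x.unsqueeze(0), y)  # 2d input trips the first ValueError above
except ValueError as err:
    print(err)
```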
58 changes: 33 additions & 25 deletions pytorch_lightning/metrics/functional/auroc.py
@@ -46,14 +46,14 @@ def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
 
 
 def _auroc_compute(
-        preds: torch.Tensor,
-        target: torch.Tensor,
-        mode: str,
-        num_classes: Optional[int] = None,
-        pos_label: Optional[int] = None,
-        average: Optional[str] = 'macro',
-        max_fpr: Optional[float] = None,
-        sample_weights: Optional[Sequence] = None,
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    mode: str,
+    num_classes: Optional[int] = None,
+    pos_label: Optional[int] = None,
+    average: Optional[str] = 'macro',
+    max_fpr: Optional[float] = None,
+    sample_weights: Optional[Sequence] = None,
 ) -> torch.Tensor:
     # binary mode override num_classes
     if mode == 'binary':
@@ -65,20 +65,26 @@ def _auroc_compute(
             raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
 
         if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
-            raise RuntimeError("`max_fpr` argument requires `torch.bucketize` which"
-                               " is not available below PyTorch version 1.6")
+            raise RuntimeError(
+                "`max_fpr` argument requires `torch.bucketize` which"
+                " is not available below PyTorch version 1.6"
+            )
 
         # max_fpr parameter is only support for binary
         if mode != 'binary':
-            raise ValueError(f"Partial AUC computation not available in "
-                             f"multilabel/multiclass setting, 'max_fpr' must be"
-                             f" set to `None`, received `{max_fpr}`.")
+            raise ValueError(
+                f"Partial AUC computation not available in"
+                f" multilabel/multiclass setting, 'max_fpr' must be"
+                f" set to `None`, received `{max_fpr}`."
+            )
 
     # calculate fpr, tpr
     if mode == 'multi-label':
         # for multilabel we iteratively evaluate roc in a binary fashion
-        output = [roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights)
-                  for i in range(num_classes)]
+        output = [
+            roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights)
+            for i in range(num_classes)
+        ]
         fpr = [o[0] for o in output]
         tpr = [o[1] for o in output]
     else:
@@ -103,8 +109,10 @@ def _auroc_compute(
             return torch.sum(torch.stack(auc_scores) * support / support.sum())
 
         allowed_average = [e.value for e in AverageMethods]
-        raise ValueError(f"Argument `average` expected to be one of the following:"
-                         f" {allowed_average} but got {average}")
+        raise ValueError(
+            f"Argument `average` expected to be one of the following:"
+            f" {allowed_average} but got {average}"
+        )
 
     return auc(fpr, tpr)

@@ -121,19 +129,19 @@ def _auroc_compute(
 
     # McClish correction: standardize result to be 0.5 if non-discriminant
     # and 1 if maximal
-    min_area = 0.5 * max_fpr ** 2
+    min_area = 0.5 * max_fpr**2
     max_area = max_fpr
     return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
 
 
 def auroc(
-        preds: torch.Tensor,
-        target: torch.Tensor,
-        num_classes: Optional[int] = None,
-        pos_label: Optional[int] = None,
-        average: Optional[str] = 'macro',
-        max_fpr: Optional[float] = None,
-        sample_weights: Optional[Sequence] = None,
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    num_classes: Optional[int] = None,
+    pos_label: Optional[int] = None,
+    average: Optional[str] = 'macro',
+    max_fpr: Optional[float] = None,
+    sample_weights: Optional[Sequence] = None,
 ) -> torch.Tensor:
     """ Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC)
     <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_
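The McClish correction kept as context above rescales a partial AUC (integrated only over FPR ≤ `max_fpr`) so that 0.5 still reads as chance and 1.0 as perfect. A standalone sketch of the arithmetic, with illustrative numbers:

```python
# Illustrative values only; `partial_auc` would come from integrating the
# truncated ROC curve as _auroc_compute does.
max_fpr = 0.5
partial_auc = 0.3

min_area = 0.5 * max_fpr**2  # chance diagonal's area up to max_fpr
max_area = max_fpr           # perfect classifier's area up to max_fpr

standardized = 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
print(standardized)  # ~0.733, comfortably above the 0.5 chance level
```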
28 changes: 14 additions & 14 deletions pytorch_lightning/metrics/functional/average_precision.py
@@ -22,20 +22,20 @@
 
 
 def _average_precision_update(
-        preds: torch.Tensor,
-        target: torch.Tensor,
-        num_classes: Optional[int] = None,
-        pos_label: Optional[int] = None,
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    num_classes: Optional[int] = None,
+    pos_label: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
     return _precision_recall_curve_update(preds, target, num_classes, pos_label)
 
 
 def _average_precision_compute(
-        preds: torch.Tensor,
-        target: torch.Tensor,
-        num_classes: int,
-        pos_label: int,
-        sample_weights: Optional[Sequence] = None
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    num_classes: int,
+    pos_label: int,
+    sample_weights: Optional[Sequence] = None
 ) -> Union[List[torch.Tensor], torch.Tensor]:
     precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
     # Return the step function integral
@@ -51,11 +51,11 @@ def _average_precision_compute(
 
 
 def average_precision(
-        preds: torch.Tensor,
-        target: torch.Tensor,
-        num_classes: Optional[int] = None,
-        pos_label: Optional[int] = None,
-        sample_weights: Optional[Sequence] = None,
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    num_classes: Optional[int] = None,
+    pos_label: Optional[int] = None,
+    sample_weights: Optional[Sequence] = None,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
     """
     Computes the average precision score.
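The "step function integral" comment refers to summing precision weighted by recall increments. A hedged standalone sketch: the helper name is ours, and the minus sign assumes recall arrives in decreasing order, as the PR-curve helpers return it.

```python
import torch

def step_function_integral(precision: torch.Tensor, recall: torch.Tensor) -> torch.Tensor:
    # Average precision as the sum over thresholds of P * (change in R);
    # recall decreases along the returned curve, hence the leading minus sign.
    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])

precision = torch.tensor([1.0, 0.5, 1.0])
recall = torch.tensor([1.0, 0.5, 0.0])
print(step_function_integral(precision, recall))  # tensor(0.7500)
```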