diff --git a/.yapfignore b/.yapfignore index b9f7d5cd47319..9b42f233daa43 100644 --- a/.yapfignore +++ b/.yapfignore @@ -20,10 +20,6 @@ pytorch_lightning/core/* # TODO pytorch_lightning/loggers/* - -# TODO -pytorch_lightning/metrics/* - # TODO pytorch_lightning/plugins/legacy/* diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py index d0a737c9f7166..6c5a29173d20a 100644 --- a/pytorch_lightning/metrics/classification/auc.py +++ b/pytorch_lightning/metrics/classification/auc.py @@ -42,6 +42,7 @@ class AUC(Metric): Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather """ + def __init__( self, reorder: bool = False, diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py index f6c69b8075ca8..a755e2bbb89cd 100644 --- a/pytorch_lightning/metrics/classification/auroc.py +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -86,6 +86,7 @@ class AUROC(Metric): tensor(0.7778) """ + def __init__( self, num_classes: Optional[int] = None, @@ -111,8 +112,9 @@ def __init__( allowed_average = (None, 'macro', 'weighted') if self.average not in allowed_average: - raise ValueError('Argument `average` expected to be one of the following:' - f' {allowed_average} but got {average}') + raise ValueError( + f'Argument `average` expected to be one of the following: {allowed_average} but got {average}' + ) if self.max_fpr is not None: if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1): @@ -146,8 +148,10 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.target.append(target) if self.mode is not None and self.mode != mode: - raise ValueError('The mode of data (binary, multi-label, multi-class) should be constant, but changed' - f' between batches from {self.mode} to {mode}') + raise ValueError( + 'The mode of data (binary, multi-label, multi-class) should be constant, but changed' + f' between batches from {self.mode} to {mode}' + ) self.mode = mode def compute(self) -> torch.Tensor: @@ -163,5 +167,5 @@ def compute(self) -> torch.Tensor: self.num_classes, self.pos_label, self.average, - self.max_fpr + self.max_fpr, ) diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index d9728ed92c1a0..f6678ddd4ae75 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -68,6 +68,7 @@ class AveragePrecision(Metric): [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] """ + def __init__( self, num_classes: Optional[int] = None, @@ -102,10 +103,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): target: Ground truth values """ preds, target, num_classes, pos_label = _average_precision_update( - preds, - target, - self.num_classes, - self.pos_label + preds, target, self.num_classes, self.pos_label ) self.preds.append(preds) self.target.append(target) diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py index 58ab695e03565..77933ab9ba56f 100644 --- a/pytorch_lightning/metrics/classification/confusion_matrix.py +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -70,6 +70,7 @@ class ConfusionMatrix(Metric): [1., 1.]]) """ + def __init__( self, num_classes: int, diff --git 
a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index 0acf4c5adcda0..e22ebc2d1c16b 100755 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -87,7 +87,9 @@ def __init__( process_group: Optional[Any] = None, ): super().__init__( - compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, ) self.num_classes = num_classes @@ -98,8 +100,10 @@ def __init__( allowed_average = ("micro", "macro", "weighted", None) if self.average not in allowed_average: - raise ValueError('Argument `average` expected to be one of the following:' - f' {allowed_average} but got {self.average}') + raise ValueError( + 'Argument `average` expected to be one of the following:' + f' {allowed_average} but got {self.average}' + ) self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") @@ -125,8 +129,9 @@ def compute(self) -> torch.Tensor: """ Computes fbeta over state. """ - return _fbeta_compute(self.true_positives, self.predicted_positives, - self.actual_positives, self.beta, self.average) + return _fbeta_compute( + self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average + ) class F1(FBeta): diff --git a/pytorch_lightning/metrics/classification/iou.py b/pytorch_lightning/metrics/classification/iou.py index 84e6803ad19e8..40567a40c807a 100644 --- a/pytorch_lightning/metrics/classification/iou.py +++ b/pytorch_lightning/metrics/classification/iou.py @@ -78,15 +78,15 @@ class IoU(ConfusionMatrix): """ def __init__( - self, - num_classes: int, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - threshold: float = 0.5, - reduction: str = 'elementwise_mean', - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, + self, + num_classes: int, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + reduction: str = 'elementwise_mean', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, ): super().__init__( num_classes=num_classes, diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 8a8c49c6fdd29..4f81c7283e202 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -82,6 +82,7 @@ class PrecisionRecallCurve(Metric): [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ + def __init__( self, num_classes: Optional[int] = None, @@ -116,18 +117,17 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): target: Ground truth values """ preds, target, num_classes, pos_label = _precision_recall_curve_update( - preds, - target, - self.num_classes, - self.pos_label + preds, target, self.num_classes, self.pos_label ) self.preds.append(preds) self.target.append(target) self.num_classes = num_classes self.pos_label = pos_label - def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + def compute( + self + ) 
-> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: """ Compute the precision-recall curve diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index d0c7ca57944ca..a5ff459f67be1 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -81,6 +81,7 @@ class ROC(Metric): tensor([1.7500, 0.7500, 0.0500])] """ + def __init__( self, num_classes: Optional[int] = None, @@ -114,19 +115,16 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): preds: Predictions from model target: Ground truth values """ - preds, target, num_classes, pos_label = _roc_update( - preds, - target, - self.num_classes, - self.pos_label - ) + preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label) self.preds.append(preds) self.target.append(target) self.num_classes = num_classes self.pos_label = pos_label - def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + def compute( + self + ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: """ Compute the receiver operating characteristic diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index ef01929ba9791..3d956030a6140 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -165,7 +165,7 @@ def __init__( if reduce == "micro": zeros_shape = [] elif reduce == "macro": - zeros_shape = (num_classes,) + zeros_shape = (num_classes, ) default, reduce_fn = lambda: torch.zeros(zeros_shape, dtype=torch.long), "sum" else: default, reduce_fn = lambda: [], None diff --git a/pytorch_lightning/metrics/functional/auc.py b/pytorch_lightning/metrics/functional/auc.py index d0660383167d3..bc56137d9836e 100644 --- a/pytorch_lightning/metrics/functional/auc.py +++ b/pytorch_lightning/metrics/functional/auc.py @@ -20,11 +20,15 @@ def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if x.ndim > 1 or y.ndim > 1: - raise ValueError(f'Expected both `x` and `y` tensor to be 1d, but got' - f' tensors with dimention {x.ndim} and {y.ndim}') + raise ValueError( + f'Expected both `x` and `y` tensor to be 1d, but got' + f' tensors with dimention {x.ndim} and {y.ndim}' + ) if x.numel() != y.numel(): - raise ValueError(f'Expected the same number of elements in `x` and `y`' - f' tensor but received {x.numel()} and {y.numel()}') + raise ValueError( + f'Expected the same number of elements in `x` and `y`' + f' tensor but received {x.numel()} and {y.numel()}' + ) return x, y diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py index a743043fa7739..fa8b34ea7b769 100644 --- a/pytorch_lightning/metrics/functional/auroc.py +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -46,14 +46,14 @@ def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens def _auroc_compute( - preds: torch.Tensor, - target: torch.Tensor, - mode: str, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = 'macro', - max_fpr: Optional[float] = None, - sample_weights: Optional[Sequence] = None, + preds: torch.Tensor, + 
target: torch.Tensor, + mode: str, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = 'macro', + max_fpr: Optional[float] = None, + sample_weights: Optional[Sequence] = None, ) -> torch.Tensor: # binary mode override num_classes if mode == 'binary': @@ -65,20 +65,26 @@ def _auroc_compute( raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") if LooseVersion(torch.__version__) < LooseVersion('1.6.0'): - raise RuntimeError("`max_fpr` argument requires `torch.bucketize` which" - " is not available below PyTorch version 1.6") + raise RuntimeError( + "`max_fpr` argument requires `torch.bucketize` which" + " is not available below PyTorch version 1.6" + ) # max_fpr parameter is only support for binary if mode != 'binary': - raise ValueError(f"Partial AUC computation not available in " - f"multilabel/multiclass setting, 'max_fpr' must be" - f" set to `None`, received `{max_fpr}`.") + raise ValueError( + f"Partial AUC computation not available in" + f" multilabel/multiclass setting, 'max_fpr' must be" + f" set to `None`, received `{max_fpr}`." + ) # calculate fpr, tpr if mode == 'multi-label': # for multilabel we iteratively evaluate roc in a binary fashion - output = [roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) - for i in range(num_classes)] + output = [ + roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) + for i in range(num_classes) + ] fpr = [o[0] for o in output] tpr = [o[1] for o in output] else: @@ -103,8 +109,10 @@ def _auroc_compute( return torch.sum(torch.stack(auc_scores) * support / support.sum()) allowed_average = [e.value for e in AverageMethods] - raise ValueError(f"Argument `average` expected to be one of the following:" - f" {allowed_average} but got {average}") + raise ValueError( + f"Argument `average` expected to be one of the following:" + f" {allowed_average} but got {average}" + ) return auc(fpr, tpr) @@ -121,19 +129,19 @@ def _auroc_compute( # McClish correction: standardize result to be 0.5 if non-discriminant # and 1 if maximal - min_area = 0.5 * max_fpr ** 2 + min_area = 0.5 * max_fpr**2 max_area = max_fpr return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) def auroc( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = 'macro', - max_fpr: Optional[float] = None, - sample_weights: Optional[Sequence] = None, + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = 'macro', + max_fpr: Optional[float] = None, + sample_weights: Optional[Sequence] = None, ) -> torch.Tensor: """ Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC) `_ diff --git a/pytorch_lightning/metrics/functional/average_precision.py b/pytorch_lightning/metrics/functional/average_precision.py index 026ea27aa11af..49dc6fed9cec6 100644 --- a/pytorch_lightning/metrics/functional/average_precision.py +++ b/pytorch_lightning/metrics/functional/average_precision.py @@ -22,20 +22,20 @@ def _average_precision_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: return 
_precision_recall_curve_update(preds, target, num_classes, pos_label) def _average_precision_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None ) -> Union[List[torch.Tensor], torch.Tensor]: precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) # Return the step function integral @@ -51,11 +51,11 @@ def _average_precision_compute( def average_precision( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - sample_weights: Optional[Sequence] = None, + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """ Computes the average precision score. diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 8ab94edd2145d..fae9e0770f88d 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -32,8 +32,8 @@ def to_onehot( - tensor: torch.Tensor, - num_classes: Optional[int] = None, + tensor: torch.Tensor, + num_classes: Optional[int] = None, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format @@ -48,10 +48,7 @@ def to_onehot( return __to(tensor, num_classes) -def to_categorical( - tensor: torch.Tensor, - argmax_dim: int = 1 -) -> torch.Tensor: +def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ Converts a tensor of probabilities to a dense label tensor @@ -67,9 +64,9 @@ def to_categorical( def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, ) -> int: """ Calculates the number of classes for a given prediction and target tensor. 
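# Illustrative sketch (not part of the patch): the dense-label <-> one-hot round trip that
# `to_onehot` and `to_categorical` above provide, written with plain torch ops. The helper
# names below are hypothetical stand-ins, not the library API.
import torch


def dense_to_onehot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # scatter a 1 into the class column of each row
    onehot = torch.zeros(labels.shape[0], num_classes, dtype=torch.long)
    return onehot.scatter_(1, labels.unsqueeze(1), 1)


def onehot_to_dense(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
    # the inverse: pick the index of the largest entry along the class dimension
    return torch.argmax(tensor, dim=argmax_dim)


labels = torch.tensor([0, 2, 1])
assert torch.equal(onehot_to_dense(dense_to_onehot(labels, num_classes=3)), labels)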
@@ -86,9 +83,10 @@ def get_num_classes( def stat_scores( - pred: torch.Tensor, - target: torch.Tensor, - class_index: int, argmax_dim: int = 1, + pred: torch.Tensor, + target: torch.Tensor, + class_index: int, + argmax_dim: int = 1, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculates the number of true positive, false positive, true negative @@ -126,11 +124,11 @@ def stat_scores( def stat_scores_multiple_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - argmax_dim: int = 1, - reduction: str = 'none', + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + argmax_dim: int = 1, + reduction: str = 'none', ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculates the number of true positive, false positive, true negative @@ -160,13 +158,13 @@ def stat_scores_multiple_classes( raise ValueError("reduction type %s not supported" % reduction) if reduction == 'none': - pred = pred.view((-1,)).long() - target = target.view((-1,)).long() + pred = pred.view((-1, )).long() + target = target.view((-1, )).long() - tps = torch.zeros((num_classes + 1,), device=pred.device) - fps = torch.zeros((num_classes + 1,), device=pred.device) - fns = torch.zeros((num_classes + 1,), device=pred.device) - sups = torch.zeros((num_classes + 1,), device=pred.device) + tps = torch.zeros((num_classes + 1, ), device=pred.device) + fps = torch.zeros((num_classes + 1, ), device=pred.device) + fns = torch.zeros((num_classes + 1, ), device=pred.device) + sups = torch.zeros((num_classes + 1, ), device=pred.device) match_true = (pred == target).float() match_false = 1 - match_true @@ -214,12 +212,12 @@ def _confmat_normalize(cm): def precision_recall( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - return_support: bool = False, - return_state: bool = False + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', + return_support: bool = False, + return_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes precision and recall for different thresholds @@ -272,10 +270,10 @@ def precision_recall( def precision( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes precision score. @@ -311,15 +309,14 @@ def precision( " It will be removed in v1.4.0", DeprecationWarning ) - return precision_recall(pred=pred, target=target, - num_classes=num_classes, class_reduction=class_reduction)[0] + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] def recall( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes recall score. 
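# Illustrative sketch (not part of the patch): how per-class precision and recall follow from
# the tp / fp / fn counts that `stat_scores_multiple_classes` above accumulates. A simplified
# stand-in assuming dense integer predictions and targets; names are hypothetical.
import torch


def per_class_precision_recall(pred: torch.Tensor, target: torch.Tensor, num_classes: int):
    tps = torch.zeros(num_classes)
    fps = torch.zeros(num_classes)
    fns = torch.zeros(num_classes)
    for c in range(num_classes):
        tps[c] = ((pred == c) & (target == c)).sum()
        fps[c] = ((pred == c) & (target != c)).sum()
        fns[c] = ((pred != c) & (target == c)).sum()
    # clamp the denominators so empty classes yield 0 instead of dividing by zero
    precision = tps / (tps + fps).clamp(min=1)
    recall = tps / (tps + fns).clamp(min=1)
    return precision, recall


pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
print(per_class_precision_recall(pred, target, num_classes=3))
# (tensor([1.0000, 0.5000, 1.0000]), tensor([1.0000, 1.0000, 0.5000]))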
@@ -354,16 +351,15 @@ def recall( " It will be removed in v1.4.0", DeprecationWarning ) - return precision_recall(pred=pred, target=target, - num_classes=num_classes, class_reduction=class_reduction)[1] + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] # todo: remove in 1.3 def roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. @@ -380,10 +376,10 @@ def roc( # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py def _roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. @@ -431,10 +427,10 @@ def _roc( # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py def multiclass_roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. @@ -481,8 +477,8 @@ def multiclass_roc( def auc( - x: torch.Tensor, - y: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, ) -> torch.Tensor: """ Computes Area Under the Curve (AUC) using the trapezoidal rule @@ -514,12 +510,10 @@ def auc( def auc_decorator() -> Callable: - rank_zero_warn( - "This `auc_decorator` was deprecated in v1.2.0." - " It will be removed in v1.4.0", DeprecationWarning - ) + rank_zero_warn("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0", DeprecationWarning) def wrapper(func_to_decorate: Callable) -> Callable: + @wraps(func_to_decorate) def new_func(*args, **kwargs) -> torch.Tensor: x, y = func_to_decorate(*args, **kwargs)[:2] @@ -538,6 +532,7 @@ def multiclass_auc_decorator() -> Callable: ) def wrapper(func_to_decorate: Callable) -> Callable: + @wraps(func_to_decorate) def new_func(*args, **kwargs) -> torch.Tensor: results = [] @@ -553,11 +548,11 @@ def new_func(*args, **kwargs) -> torch.Tensor: def auroc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., - max_fpr: float = None, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., + max_fpr: float = None, ) -> torch.Tensor: """ Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores @@ -589,15 +584,16 @@ def auroc( " `pytorch_lightning.metrics.functional.auroc import auroc`." 
" It will be removed in v1.4.0", DeprecationWarning ) - return __auroc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, max_fpr=max_fpr, - num_classes=1) + return __auroc( + preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, max_fpr=max_fpr, num_classes=1 + ) def multiclass_auroc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, ) -> torch.Tensor: """ Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass @@ -635,30 +631,33 @@ def multiclass_auroc( if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)): raise ValueError( "Multiclass AUROC metric expects the target scores to be" - " probabilities, i.e. they should sum up to 1.0 over classes") + " probabilities, i.e. they should sum up to 1.0 over classes" + ) if torch.unique(target).size(0) != pred.size(1): raise ValueError( f"Number of classes found in in 'target' ({torch.unique(target).size(0)})" f" does not equal the number of columns in 'pred' ({pred.size(1)})." " Multiclass AUROC is not defined when all of the classes do not" - " occur in the target labels.") + " occur in the target labels." + ) if num_classes is not None and num_classes != pred.size(1): raise ValueError( f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal" - f" the number of classes passed in 'num_classes' ({num_classes}).") + f" the number of classes passed in 'num_classes' ({num_classes})." + ) return __auroc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) def dice_score( - pred: torch.Tensor, - target: torch.Tensor, - bg: bool = False, - nan_score: float = 0.0, - no_fg_score: float = 0.0, - reduction: str = 'elementwise_mean', + pred: torch.Tensor, + target: torch.Tensor, + bg: bool = False, + nan_score: float = 0.0, + no_fg_score: float = 0.0, + reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ Compute dice score from prediction scores @@ -709,12 +708,12 @@ def dice_score( # todo: remove in 1.4 def iou( - pred: torch.Tensor, - target: torch.Tensor, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + pred: torch.Tensor, + target: torch.Tensor, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ Intersection over union, or Jaccard index calculation. @@ -772,10 +771,10 @@ def iou( # todo: remove in 1.3 def precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., ): """ Computes precision-recall pairs for different thresholds. @@ -793,10 +792,10 @@ def precision_recall_curve( # todo: remove in 1.3 def multiclass_precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, ): """ Computes precision-recall pairs for different thresholds given a multiclass scores. 
@@ -816,10 +815,10 @@ def multiclass_precision_recall_curve( # todo: remove in 1.3 def average_precision( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., ): """ Compute average precision from prediction scores. diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py index c6db7e1e45da4..1810af579653e 100644 --- a/pytorch_lightning/metrics/functional/confusion_matrix.py +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -19,22 +19,20 @@ from pytorch_lightning.utilities import rank_zero_warn -def _confusion_matrix_update(preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - threshold: float = 0.5) -> torch.Tensor: +def _confusion_matrix_update( + preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5 +) -> torch.Tensor: preds, target, mode = _input_format_classification(preds, target, threshold) if mode not in ('binary', 'multi-label'): preds = preds.argmax(dim=1) target = target.argmax(dim=1) unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long) - bins = torch.bincount(unique_mapping, minlength=num_classes ** 2) + bins = torch.bincount(unique_mapping, minlength=num_classes**2) confmat = bins.reshape(num_classes, num_classes) return confmat -def _confusion_matrix_compute(confmat: torch.Tensor, - normalize: Optional[str] = None) -> torch.Tensor: +def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] = None) -> torch.Tensor: allowed_normalize = ('true', 'pred', 'all', None) assert normalize in allowed_normalize, \ f"Argument average needs to one of the following: {allowed_normalize}" @@ -55,11 +53,11 @@ def _confusion_matrix_compute(confmat: torch.Tensor, def confusion_matrix( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - normalize: Optional[str] = None, - threshold: float = 0.5 + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + normalize: Optional[str] = None, + threshold: float = 0.5 ) -> torch.Tensor: """ Computes the confusion matrix. Works with binary, multiclass, and multilabel data. 
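# Illustrative sketch (not part of the patch): the bincount trick used by `_confusion_matrix_update`
# above -- encode every (target, prediction) pair as a single flat index, count the indices, and
# reshape the counts into a [num_classes, num_classes] matrix (rows = true class, columns = predicted).
import torch


def confmat_via_bincount(preds: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor:
    unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
    bins = torch.bincount(unique_mapping, minlength=num_classes**2)
    return bins.reshape(num_classes, num_classes)


preds = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
print(confmat_via_bincount(preds, target, num_classes=3))
# tensor([[1, 0, 0],
#         [0, 1, 0],
#         [0, 1, 1]])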
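# Illustrative sketch (not part of the patch): the per-class Dice coefficient that `dice_score`
# (earlier in this diff) reduces over, expressed directly from tp / fp / fn counts. The function
# name is a hypothetical stand-in; the real implementation also handles background, nan scores and
# absent classes.
import torch


def dice_for_class(pred: torch.Tensor, target: torch.Tensor, class_index: int) -> torch.Tensor:
    tp = ((pred == class_index) & (target == class_index)).sum().float()
    fp = ((pred == class_index) & (target != class_index)).sum().float()
    fn = ((pred != class_index) & (target == class_index)).sum().float()
    denom = 2 * tp + fp + fn
    # Dice = 2 * TP / (2 * TP + FP + FN); undefined when the class never occurs
    return 2 * tp / denom if denom > 0 else torch.tensor(0.0)


pred = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])
print(dice_for_class(pred, target, class_index=2))  # tensor(0.8000)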
diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index 9309e6ef34ad9..617d800c754e3 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -24,15 +24,15 @@ def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tup def _explained_variance_compute( - preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: diff_avg = torch.mean(target - preds, dim=0) - numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0) + numerator = torch.mean((target - preds - diff_avg)**2, dim=0) target_avg = torch.mean(target, dim=0) - denominator = torch.mean((target - target_avg) ** 2, dim=0) + denominator = torch.mean((target - target_avg)**2, dim=0) # Take care of division by zero nonzero_numerator = numerator != 0 @@ -54,9 +54,9 @@ def _explained_variance_compute( def explained_variance( - preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ Computes explained variance. diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py index c294d29805a6f..07633e8174db1 100755 --- a/pytorch_lightning/metrics/functional/f_beta.py +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -19,15 +19,13 @@ def _fbeta_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - threshold: float = 0.5, - multilabel: bool = False + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + threshold: float = 0.5, + multilabel: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - preds, target = _input_format_classification_one_hot( - num_classes, preds, target, threshold, multilabel - ) + preds, target = _input_format_classification_one_hot(num_classes, preds, target, threshold, multilabel) true_positives = torch.sum(preds * target, dim=1) predicted_positives = torch.sum(preds, dim=1) actual_positives = torch.sum(target, dim=1) @@ -35,11 +33,11 @@ def _fbeta_update( def _fbeta_compute( - true_positives: torch.Tensor, - predicted_positives: torch.Tensor, - actual_positives: torch.Tensor, - beta: float = 1.0, - average: str = "micro" + true_positives: torch.Tensor, + predicted_positives: torch.Tensor, + actual_positives: torch.Tensor, + beta: float = 1.0, + average: str = "micro" ) -> torch.Tensor: if average == "micro": precision = true_positives.sum().float() / predicted_positives.sum() @@ -48,19 +46,19 @@ def _fbeta_compute( precision = true_positives.float() / predicted_positives recall = true_positives.float() / actual_positives - num = (1 + beta ** 2) * precision * recall - denom = beta ** 2 * precision + recall + num = (1 + beta**2) * precision * recall + denom = beta**2 * precision + recall return class_reduce(num, denom, weights=actual_positives, class_reduction=average) def fbeta( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - beta: float = 1.0, - threshold: float = 0.5, - average: str = "micro", - multilabel: bool = False + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + beta: float = 1.0, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = 
False ) -> torch.Tensor: """ Computes f_beta metric. @@ -107,12 +105,12 @@ def fbeta( def f1( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - threshold: float = 0.5, - average: str = "micro", - multilabel: bool = False + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False ) -> torch.Tensor: """ Computes F1 metric. F1 metrics correspond to a equally weighted average of the diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index c2cac673024a7..60409751fc9f0 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -19,7 +19,9 @@ def _hamming_distance_update( - preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, ) -> Tuple[torch.Tensor, int]: preds, target, _ = _input_format_classification(preds, target, threshold=threshold) diff --git a/pytorch_lightning/metrics/functional/iou.py b/pytorch_lightning/metrics/functional/iou.py index f9dcb4b1b401f..1f539215ccd59 100644 --- a/pytorch_lightning/metrics/functional/iou.py +++ b/pytorch_lightning/metrics/functional/iou.py @@ -21,11 +21,11 @@ def _iou_from_confmat( - confmat: torch.Tensor, - num_classes: int, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - reduction: str = 'elementwise_mean', + confmat: torch.Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + reduction: str = 'elementwise_mean', ): intersection = torch.diag(confmat) union = confmat.sum(0) + confmat.sum(1) - intersection @@ -44,13 +44,13 @@ def _iou_from_confmat( def iou( - pred: torch.Tensor, - target: torch.Tensor, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - threshold: float = 0.5, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + pred: torch.Tensor, + target: torch.Tensor, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', ) -> torch.Tensor: r""" Computes `Intersection over union, or Jaccard index calculation `_: diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py index ef8c1c289e2dc..57c6fc1ece4e5 100644 --- a/pytorch_lightning/metrics/functional/nlp.py +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -45,10 +45,10 @@ def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: def bleu_score( - translate_corpus: Sequence[str], - reference_corpus: Sequence[str], - n_gram: int = 4, - smooth: bool = False + translate_corpus: Sequence[str], + reference_corpus: Sequence[str], + n_gram: int = 4, + smooth: bool = False ) -> torch.Tensor: """ Calculate BLEU score of machine translated text with one or more references diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py index 089148bf23190..4eab13e6bbb88 100644 --- a/pytorch_lightning/metrics/functional/precision_recall_curve.py +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -20,10 +20,10 @@ def _binary_clf_curve( - preds: torch.Tensor, - target: torch.Tensor, - sample_weights: Optional[Sequence] = None, - pos_label: int = 1., + preds: torch.Tensor, + target: torch.Tensor, + 
sample_weights: Optional[Sequence] = None, + pos_label: int = 1., ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py @@ -63,15 +63,13 @@ def _binary_clf_curve( def _precision_recall_curve_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") # single class evaluation if len(preds.shape) == len(target.shape): num_classes = 1 @@ -84,12 +82,16 @@ def _precision_recall_curve_update( # multi class evaluation if len(preds.shape) == len(target.shape) + 1: if pos_label is not None: - rank_zero_warn('Argument `pos_label` should be `None` when running' - f'multiclass precision recall curve. Got {pos_label}') + rank_zero_warn( + 'Argument `pos_label` should be `None` when running' + f' multiclass precision recall curve. Got {pos_label}' + ) if num_classes != preds.shape[1]: - raise ValueError(f'Argument `num_classes` was set to {num_classes} in' - f'metric `precision_recall_curve` but detected {preds.shape[1]}' - 'number of classes from predictions') + raise ValueError( + f'Argument `num_classes` was set to {num_classes} in' + f' metric `precision_recall_curve` but detected {preds.shape[1]}' + ' number of classes from predictions' + ) preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) target = target.flatten() @@ -97,20 +99,17 @@ def _precision_recall_curve_update( def _precision_recall_curve_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: if num_classes == 1: fps, tps, thresholds = _binary_clf_curve( - preds=preds, - target=target, - sample_weights=sample_weights, - pos_label=pos_label + preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label ) precision = tps / (tps + fps) @@ -123,13 +122,9 @@ def _precision_recall_curve_compute( # need to call reversed explicitly, since including that to slice would # introduce negative strides that are not yet supported in pytorch - precision = torch.cat([reversed(precision[sl]), - torch.ones(1, dtype=precision.dtype, - device=precision.device)]) + precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) - recall = torch.cat([reversed(recall[sl]), - torch.zeros(1, dtype=recall.dtype, - device=recall.device)]) + recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) thresholds = reversed(thresholds[sl]).clone() @@ -154,13 +149,13 @@ def _precision_recall_curve_compute( def 
precision_recall_curve( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: """ Computes precision-recall pairs for different thresholds. @@ -215,6 +210,5 @@ def precision_recall_curve( [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ - preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, - num_classes, pos_label) + preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py index fb7f9e47e1afa..c0e95a14bfcd5 100644 --- a/pytorch_lightning/metrics/functional/psnr.py +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -1,4 +1,3 @@ - from typing import Optional, Tuple import torch diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index 82117dd688064..ef8a20c806ee9 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -20,13 +20,15 @@ def _r2score_update( - preds: torch.tensor, - target: torch.Tensor, + preds: torch.tensor, + target: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: _check_same_shape(preds, target) if preds.ndim > 2: - raise ValueError('Expected both prediction and target to be 1D or 2D tensors,' - f' but recevied tensors with dimension {preds.shape}') + raise ValueError( + 'Expected both prediction and target to be 1D or 2D tensors,' + f' but recevied tensors with dimension {preds.shape}' + ) if len(preds) < 2: raise ValueError('Needs atleast two samples to calculate r2 score.') @@ -38,12 +40,14 @@ def _r2score_update( return sum_squared_error, sum_error, residual, total -def _r2score_compute(sum_squared_error: torch.Tensor, - sum_error: torch.Tensor, - residual: torch.Tensor, - total: torch.Tensor, - adjusted: int = 0, - multioutput: str = "uniform_average") -> torch.Tensor: +def _r2score_compute( + sum_squared_error: torch.Tensor, + sum_error: torch.Tensor, + residual: torch.Tensor, + total: torch.Tensor, + adjusted: int = 0, + multioutput: str = "uniform_average" +) -> torch.Tensor: mean_error = sum_error / total diff = sum_squared_error - sum_error * mean_error raw_scores = 1 - (residual / diff) @@ -56,31 +60,32 @@ def _r2score_compute(sum_squared_error: torch.Tensor, diff_sum = torch.sum(diff) r2score = torch.sum(diff / diff_sum * raw_scores) else: - raise ValueError('Argument `multioutput` must be either `raw_values`,' - f' `uniform_average` or `variance_weighted`. Received {multioutput}.') + raise ValueError( + 'Argument `multioutput` must be either `raw_values`,' + f' `uniform_average` or `variance_weighted`. Received {multioutput}.' 
+ ) if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or' - ' equal to 0.') + raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.') if adjusted != 0: if adjusted > total - 1: - rank_zero_warn("More independent regressions than datapoints in" - " adjusted r2 score. Falls back to standard r2 score.", - UserWarning) + rank_zero_warn( + "More independent regressions than datapoints in" + " adjusted r2 score. Falls back to standard r2 score.", UserWarning + ) elif adjusted == total - 1: - rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" - " standard r2 score.", UserWarning) + rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" " standard r2 score.", UserWarning) else: r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) return r2score def r2score( - preds: torch.Tensor, - target: torch.Tensor, - adjusted: int = 0, - multioutput: str = "uniform_average", + preds: torch.Tensor, + target: torch.Tensor, + adjusted: int = 0, + multioutput: str = "uniform_average", ) -> torch.Tensor: r""" Computes r2 score also known as `coefficient of determination diff --git a/pytorch_lightning/metrics/functional/roc.py b/pytorch_lightning/metrics/functional/roc.py index bf07d2799daba..16ecf18b91e11 100644 --- a/pytorch_lightning/metrics/functional/roc.py +++ b/pytorch_lightning/metrics/functional/roc.py @@ -22,29 +22,26 @@ def _roc_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: return _precision_recall_curve_update(preds, target, num_classes, pos_label) def _roc_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: if num_classes == 1: fps, tps, thresholds = _binary_clf_curve( - preds=preds, - target=target, - sample_weights=sample_weights, - pos_label=pos_label + preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label ) # Add an extra threshold position # to make sure that the curve starts at (0, 0) @@ -81,13 +78,13 @@ def _roc_compute( def roc( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]]]: """ Computes the Receiver Operating Characteristic (ROC). 
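# Illustrative sketch (not part of the patch): the adjusted r2 correction applied in
# `_r2score_compute` above -- shrink the raw score by the number of regressors so that adding
# uninformative features cannot inflate it. Here `n` is the sample count and `k` the number of
# independent regressors (the `adjusted` argument above); the function name is hypothetical.
import torch


def adjusted_r2(r2: torch.Tensor, n: int, k: int) -> torch.Tensor:
    # adjusted R^2 = 1 - (1 - R^2) * (n - 1) / (n - k - 1)
    return 1 - (1 - r2) * (n - 1) / (n - k - 1)


print(adjusted_r2(torch.tensor(0.90), n=50, k=5))  # tensor(0.8886)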
diff --git a/pytorch_lightning/metrics/functional/self_supervised.py b/pytorch_lightning/metrics/functional/self_supervised.py index a6952ee6f678d..de70ae5335f31 100644 --- a/pytorch_lightning/metrics/functional/self_supervised.py +++ b/pytorch_lightning/metrics/functional/self_supervised.py @@ -15,10 +15,10 @@ def embedding_similarity( - batch: torch.Tensor, - similarity: str = 'cosine', - reduction: str = 'none', - zero_diagonal: bool = True + batch: torch.Tensor, + similarity: str = 'cosine', + reduction: str = 'none', + zero_diagonal: bool = True ) -> torch.Tensor: """ Computes representation similarity diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py index a978ce8268161..e0a1d97ff5fd1 100644 --- a/pytorch_lightning/metrics/functional/ssim.py +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -25,8 +25,9 @@ def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.de return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) -def _gaussian_kernel(channel: int, kernel_size: Sequence[int], sigma: Sequence[float], - dtype: torch.dtype, device: torch.device): +def _gaussian_kernel( + channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device +): gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) @@ -92,7 +93,7 @@ def _ssim_compute( input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) outputs = F.conv2d(input_list, kernel, groups=channel) - output_list = [outputs[x * preds.size(0): (x + 1) * preds.size(0)] for x in range(len(outputs))] + output_list = [outputs[x * preds.size(0):(x + 1) * preds.size(0)] for x in range(len(outputs))] mu_pred_sq = output_list[0].pow(2) mu_target_sq = output_list[1].pow(2) diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index cabd368877bad..c75af2ec4cab0 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -21,11 +21,13 @@ def _del_column(tensor: torch.Tensor, index: int): """ Delete the column at index.""" - return torch.cat([tensor[:, :index], tensor[:, (index + 1) :]], 1) + return torch.cat([tensor[:, :index], tensor[:, (index + 1):]], 1) def _stat_scores( - preds: torch.Tensor, target: torch.Tensor, reduce: str = "micro" + preds: torch.Tensor, + target: torch.Tensor, + reduce: str = "micro", ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Calculate the number of tp, fp, tn, fn. 
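# Illustrative sketch (not part of the patch): the separable Gaussian window built by
# `_gaussian_kernel` above for SSIM -- the outer product of two normalized 1D Gaussians.
# Coordinate conventions here are simplified and the names are hypothetical stand-ins.
import torch


def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor:
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    gauss = torch.exp(-coords**2 / (2 * sigma**2))
    return (gauss / gauss.sum()).unsqueeze(0)  # shape (1, kernel_size)


def gaussian_2d(kernel_size: int, sigma: float) -> torch.Tensor:
    g = gaussian_1d(kernel_size, sigma)
    return torch.matmul(g.t(), g)  # (kernel_size, 1) x (1, kernel_size)


kernel = gaussian_2d(kernel_size=11, sigma=1.5)
print(kernel.shape, float(kernel.sum()))  # torch.Size([11, 11]) and a sum of ~1.0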
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 794e696e98f8a..e5f8a0fd48cd6 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -122,8 +122,7 @@ def add_state( """ if ( - not isinstance(default, torch.Tensor) - and not isinstance(default, list) # noqa: W503 + not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503 or (isinstance(default, list) and len(default) != 0) # noqa: W503 ): raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") @@ -193,6 +192,7 @@ def _sync_dist(self, dist_sync_fn=gather_all_tensors): setattr(self, attr, reduced) def _wrap_update(self, update): + @functools.wraps(update) def wrapped_func(*args, **kwargs): self._computed = None @@ -201,6 +201,7 @@ def wrapped_func(*args, **kwargs): return wrapped_func def _wrap_compute(self, compute): + @functools.wraps(compute) def wrapped_func(*args, **kwargs): # return cached value @@ -314,8 +315,7 @@ def _filter_kwargs(self, **kwargs): _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) filtered_kwargs = { k: v - for k, v in kwargs.items() - if k in self._update_signature.parameters.keys() + for k, v in kwargs.items() if k in self._update_signature.parameters.keys() and self._update_signature.parameters[k].kind not in _params } @@ -544,14 +544,16 @@ def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric] for name, metric in metrics.items(): if not isinstance(metric, Metric): raise ValueError( - f"Value {metric} belonging to key {name}" " is not an instance of `pl.metrics.Metric`" + f"Value {metric} belonging to key {name}" + " is not an instance of `pl.metrics.Metric`" ) self[name] = metric elif isinstance(metrics, (tuple, list)): for metric in metrics: if not isinstance(metric, Metric): raise ValueError( - f"Input {metric} to `MetricCollection` is not a instance" " of `pl.metrics.Metric`" + f"Input {metric} to `MetricCollection` is not a instance" + " of `pl.metrics.Metric`" ) name = metric.__class__.__name__ if name in self: diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 44f4d33898f29..77f6c1363a566 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -81,6 +81,7 @@ class R2Score(Metric): >>> r2score(preds, target) tensor([0.9654, 0.9082]) """ + def __init__( self, num_outputs: int = 1, @@ -101,8 +102,7 @@ def __init__( self.num_outputs = num_outputs if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or' - ' equal to 0.') + raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.') self.adjusted = adjusted allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') @@ -136,5 +136,6 @@ def compute(self) -> torch.Tensor: """ Computes r2 score over the metric states. 
""" - return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual, - self.total, self.adjusted, self.multioutput) + return _r2score_compute( + self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput + ) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 2a8d7fee0fa91..7a4dc726ea555 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -44,7 +44,11 @@ def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): def _input_format_classification_one_hot( - num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, multilabel: bool = False + num_classes: int, + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + multilabel: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """Convert preds and target tensors into one hot spare label tensors @@ -262,7 +266,8 @@ def class_reduce( return fraction raise ValueError( - f"Reduction parameter {class_reduction} unknown." f" Choose between one of these: {valid_reduction}" + f"Reduction parameter {class_reduction} unknown." + f" Choose between one of these: {valid_reduction}" ) diff --git a/setup.cfg b/setup.cfg index dc351647a9986..35fea551d77a8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -114,6 +114,7 @@ max-line-length = 120 based_on_style = pep8 spaces_before_comment = 2 split_before_logical_operator = true +split_before_arithmetic_operator = true COLUMN_LIMIT = 120 COALESCE_BRACKETS = true DEDENT_CLOSING_BRACKETS = true