From 85310c13d9d7276c7170be0527812e22da83492d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 20:29:46 +0100 Subject: [PATCH 01/11] docs --- pytorch_lightning/metrics/compositional.py | 7 ++-- pytorch_lightning/metrics/metric.py | 12 +++---- pytorch_lightning/metrics/utils.py | 40 ++++++++++------------ 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index 5961714209d40..baa631bbeb25d 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -21,10 +21,9 @@ class CompositionalMetric(__CompositionalMetric): - r""" - This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`. - - .. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0. """ def __init__( diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 145a13a251250..f856cb39e3da4 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -21,9 +21,8 @@ class Metric(__Metric): r""" - This implementation refers to :class:`~torchmetrics.Metric`. - - .. warning:: This metric is deprecated, use ``torchmetrics.Metric``. Will be removed in v1.5.0. + .. deprecated:: + Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. """ def __init__( @@ -46,10 +45,9 @@ def __init__( class MetricCollection(__MetricCollection): - r""" - This implementation refers to :class:`~torchmetrics.MetricCollection`. - - .. warning:: This metric is deprecated, use ``torchmetrics.MetricCollection``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0. """ def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 63c6892cb2987..e751e30dc1284 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -78,8 +78,9 @@ def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: - r""" - .. warning:: This function is deprecated, use ``torchmetrics.utilities.data.to_onehot``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.data.to_onehot`. Will be removed in v1.5.0. """ rank_zero_warn( "This `to_onehot` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.to_onehot`." @@ -89,10 +90,9 @@ def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.data.select_topk``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.data.select_topk`. Will be removed in v1.5.0. """ rank_zero_warn( "This `select_topk` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.select_topk`." @@ -102,10 +102,9 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.data.to_categorical``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.data.to_categorical`. Will be removed in v1.5.0. """ rank_zero_warn( "This `to_categorical` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.to_categorical`." @@ -115,10 +114,9 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.data.get_num_classes``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.data.get_num_classes`. Will be removed in v1.5.0. """ rank_zero_warn( "This `get_num_classes` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.get_num_classes`." @@ -128,10 +126,9 @@ def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optio def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.reduce``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.reduce`. Will be removed in v1.5.0. """ rank_zero_warn( "This `reduce` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.reduce`." @@ -143,10 +140,9 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.class_reduce``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.class_reduce`. Will be removed in v1.5.0. """ rank_zero_warn( "This `class_reduce` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.class_reduce`." From 069128261a21f1e25e7e19bb7110e213e1e81cb2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 00:34:24 +0100 Subject: [PATCH 02/11] wrapper --- pytorch_lightning/metrics/utils.py | 56 +++++--------------- pytorch_lightning/utilities/deprecation.py | 61 ++++++++++++++++++++++ tests/utilities/test_deprecation.py | 20 +++++++ 3 files changed, 94 insertions(+), 43 deletions(-) create mode 100644 pytorch_lightning/utilities/deprecation.py create mode 100644 tests/utilities/test_deprecation.py diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index e751e30dc1284..e5ec25a9ad556 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -24,28 +24,22 @@ from torchmetrics.utilities.distributed import class_reduce as __class_reduce from torchmetrics.utilities.distributed import reduce as __reduce -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated_func +@deprecated_func(target_func=__dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_cat(x): - rank_zero_warn( - "This `dim_zero_cat` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning - ) - return __dim_zero_cat(x) + pass +@deprecated_func(target_func=__dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_sum(x): - rank_zero_warn( - "This `dim_zero_sum` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning - ) - return __dim_zero_sum(x) + pass +@deprecated_func(target_func=__dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_mean(x): - rank_zero_warn( - "This `dim_zero_mean` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning - ) - return __dim_zero_mean(x) + pass def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: @@ -77,66 +71,47 @@ def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()] +@deprecated_func(target_func=__to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: """ .. deprecated:: Use :func:`torchmetrics.utilities.data.to_onehot`. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `to_onehot` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.to_onehot`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __to_onehot(label_tensor=label_tensor, num_classes=num_classes) +@deprecated_func(target_func=__select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ .. deprecated:: Use :func:`torchmetrics.utilities.data.select_topk`. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `select_topk` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.select_topk`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __select_topk(prob_tensor=prob_tensor, topk=topk, dim=dim) +@deprecated_func(target_func=__to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ .. deprecated:: Use :func:`torchmetrics.utilities.data.to_categorical`. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `to_categorical` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.to_categorical`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __to_categorical(tensor=tensor, argmax_dim=argmax_dim) +@deprecated_func(target_func=__get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: """ .. deprecated:: Use :func:`torchmetrics.utilities.data.get_num_classes`. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `get_num_classes` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.get_num_classes`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __get_num_classes(pred=pred, target=target, num_classes=num_classes) +@deprecated_func(target_func=__reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ .. deprecated:: Use :func:`torchmetrics.utilities.reduce`. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `reduce` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.reduce`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __reduce(to_reduce=to_reduce, reduction=reduction) +@deprecated_func(target_func=__class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: @@ -144,8 +119,3 @@ def class_reduce( .. deprecated:: Use :func:`torchmetrics.utilities.class_reduce`. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `class_reduce` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.class_reduce`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __class_reduce(num=num, denom=denom, weights=weights, class_reduction=class_reduction) diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py new file mode 100644 index 0000000000000..336d833e229cd --- /dev/null +++ b/pytorch_lightning/utilities/deprecation.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from functools import wraps +from typing import Any, Callable, List, Tuple + +from pytorch_lightning.utilities import rank_zero_warn + + +def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]]: + """Parse function arguments, types and default values + + Example: + >>> get_func_arguments_and_types(get_func_arguments_and_types) + [('func', typing.Callable, )] + """ + func_default_params = inspect.signature(func).parameters + name_type_default = [] + for arg in func_default_params: + arg_type = func_default_params[arg].annotation + arg_default = func_default_params[arg].default + name_type_default.append((arg, arg_type, arg_default)) + return name_type_default + + +def deprecated_func(target_func: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable: + + def inner_function(func): + + @wraps(func) + def wrapper(*args, **kwargs): + target_func_str = f'{target_func.__module__}.{target_func.__name__}' + rank_zero_warn( + f"This `{func.__name__}` was deprecated since v{ver_deprecate} in favor of `{target_func_str}`." + f" It will be removed in v{ver_remove}.", DeprecationWarning + ) + if args: # in case any args passed move them to kwargs + # parse only the argument names + cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)] + # convert args to kwargs + kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)] + assert all(arg in target_args for arg in kwargs), \ + "Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs] + # all args were already moved to kwargs + return target_func(**kwargs) + + return wrapper + + return inner_function diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py new file mode 100644 index 0000000000000..730838686a060 --- /dev/null +++ b/tests/utilities/test_deprecation.py @@ -0,0 +1,20 @@ +import pytest + +from pytorch_lightning.utilities.deprecation import deprecated_func + + +def my_sum(a, b=3): + return a + b + + +@deprecated_func(target_func=my_sum, ver_deprecate="0.1", ver_remove="0.5") +def dep_sum(a, b): + pass + + +def test_deprecated_func(): + with pytest.deprecated_call( + match='This `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + ' It will be removed in v0.5.' + ): + assert dep_sum(2, b=5) == 7 From cab6e224b7eb250b212e7ee2004760237b6ccc7c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 00:44:13 +0100 Subject: [PATCH 03/11] test --- pytorch_lightning/metrics/compositional.py | 4 ++-- tests/deprecated_api/test_remove_1-5_metrics.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index baa631bbeb25d..975b8280f77d5 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -33,7 +33,7 @@ def __init__( metric_b: Union[Metric, int, float, torch.Tensor, None], ): rank_zero_warn( - "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." - " It will be removed in v1.5.0", DeprecationWarning + "This `CompositionalMetric` was deprecated since v1.3.0 in favor of" + " `torchmetrics.metric.CompositionalMetric`. It will be removed in v1.5.0", DeprecationWarning ) super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b) diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index b2fa4f69f74b9..fcbeaa2e78819 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -18,6 +18,8 @@ from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot +from pytorch_lightning.metrics import MetricCollection, Accuracy + def test_v1_5_0_metrics_utils(): x = torch.tensor([1, 2, 3]) @@ -34,3 +36,14 @@ def test_v1_5_0_metrics_utils(): x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) with pytest.deprecated_call(match="It will be removed in v1.5.0"): assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int)) + + +def test_v1_5_0_metrics_collection(): + target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + with pytest.deprecated_call( + match="This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`." + " It will be removed in v1.5.0" + ): + metrics = MetricCollection([Accuracy()]) + assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]} From 9b66f5c3c5328792e770e1b368b1464ccab392f1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 10:07:10 +0100 Subject: [PATCH 04/11] test --- pytorch_lightning/core/step_result.py | 2 +- .../metrics/classification/accuracy.py | 2 +- .../metrics/classification/auc.py | 2 +- .../metrics/classification/auroc.py | 2 +- .../classification/average_precision.py | 2 +- .../classification/confusion_matrix.py | 2 +- .../metrics/classification/f_beta.py | 2 +- .../classification/hamming_distance.py | 2 +- .../classification/precision_recall_curve.py | 2 +- .../metrics/classification/roc.py | 2 +- .../metrics/classification/stat_scores.py | 2 +- pytorch_lightning/metrics/metric.py | 16 +++++++-------- .../metrics/regression/explained_variance.py | 2 +- .../metrics/regression/mean_absolute_error.py | 2 +- .../metrics/regression/mean_squared_error.py | 2 +- .../regression/mean_squared_log_error.py | 2 +- pytorch_lightning/metrics/regression/psnr.py | 2 +- .../metrics/regression/r2score.py | 2 +- pytorch_lightning/metrics/regression/ssim.py | 2 +- .../metrics/retrieval/retrieval_metric.py | 2 +- pytorch_lightning/metrics/utils.py | 20 +++++++++---------- .../logger_connector/metrics_holder.py | 3 +-- pytorch_lightning/utilities/deprecation.py | 15 +++++++++++--- tests/core/test_metric_result_integration.py | 2 +- .../deprecated_api/test_remove_1-5_metrics.py | 7 +++---- .../classification/test_precision_recall.py | 3 ++- tests/metrics/retrieval/test_map.py | 2 +- tests/metrics/test_metric_lightning.py | 2 +- tests/metrics/utils.py | 3 +-- tests/utilities/test_deprecation.py | 4 ++-- 30 files changed, 60 insertions(+), 55 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index f8d7a2ffe3a23..3961586f4946a 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -20,8 +20,8 @@ import torch from torch import Tensor +from torchmetrics import Metric -from pytorch_lightning.metrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 343e979dd3e0c..367c9b029d841 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update -from pytorch_lightning.metrics.metric import Metric class Accuracy(Metric): diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py index 6c5a29173d20a..76c1959a8603a 100644 --- a/pytorch_lightning/metrics/classification/auc.py +++ b/pytorch_lightning/metrics/classification/auc.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py index 6b9b5ae9f021f..7d8ba7368e45d 100644 --- a/pytorch_lightning/metrics/classification/auroc.py +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -15,9 +15,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index f9c7bde158383..adcdd86ed1ca8 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -14,9 +14,9 @@ from typing import Any, List, Optional, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py index c3defc82bc92d..112fb4940e6e2 100644 --- a/pytorch_lightning/metrics/classification/confusion_matrix.py +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -14,9 +14,9 @@ from typing import Any, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update -from pytorch_lightning.metrics.metric import Metric class ConfusionMatrix(Metric): diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index ae01b80966868..a46b01a1aa8b7 100644 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -14,9 +14,9 @@ from typing import Any, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index adf1086f3c85f..dceb90c0a4ca9 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_compute, _hamming_distance_update -from pytorch_lightning.metrics.metric import Metric class HammingDistance(Metric): diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 5a02a99ed17fd..ccf821d829d78 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -14,12 +14,12 @@ from typing import Any, List, Optional, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.precision_recall_curve import ( _precision_recall_curve_compute, _precision_recall_curve_update, ) -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 598646cde3861..30ca0b4fe6925 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -14,9 +14,9 @@ from typing import Any, List, Optional, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.roc import _roc_compute, _roc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index 3807d7079b508..672b0f41c6fc5 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional, Tuple import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update -from pytorch_lightning.metrics.metric import Metric class StatScores(Metric): diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index f856cb39e3da4..d4d621b158303 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -13,13 +13,14 @@ # limitations under the License. from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from torchmetrics import Metric as __Metric -from torchmetrics import MetricCollection as __MetricCollection +from torchmetrics import Metric as _Metric +from torchmetrics.collections import MetricCollection as _MetricCollection +from pytorch_lightning.utilities.deprecation import _deprecated from pytorch_lightning.utilities.distributed import rank_zero_warn -class Metric(__Metric): +class Metric(_Metric): r""" .. deprecated:: Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. @@ -44,15 +45,12 @@ def __init__( ) -class MetricCollection(__MetricCollection): +class MetricCollection(_MetricCollection): """ .. deprecated:: Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0. """ + @_deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): - rank_zero_warn( - "This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`." - " It will be removed in v1.5.0", DeprecationWarning - ) - super().__init__(metrics=metrics) + pass diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index fc033fcd16759..8b0259694ef4c 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.explained_variance import ( _explained_variance_compute, _explained_variance_update, ) -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index ca184daf736b8..484ccbe83284e 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_absolute_error import ( _mean_absolute_error_compute, _mean_absolute_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanAbsoluteError(Metric): diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 09f275ded8638..c26371514e7cd 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_squared_error import ( _mean_squared_error_compute, _mean_squared_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanSquaredError(Metric): diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index 18105e687b0b1..caaf09a3663ff 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_squared_log_error import ( _mean_squared_log_error_compute, _mean_squared_log_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanSquaredLogError(Metric): diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py index 8a38bf515ebca..746ff1e52d574 100644 --- a/pytorch_lightning/metrics/regression/psnr.py +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -14,10 +14,10 @@ from typing import Any, Optional, Sequence, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning import utilities from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update -from pytorch_lightning.metrics.metric import Metric class PSNR(Metric): diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 40d9d24711375..8156b8bc72d48 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update -from pytorch_lightning.metrics.metric import Metric class R2Score(Metric): diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py index 09b55fb2bb456..a3bbab938ffad 100644 --- a/pytorch_lightning/metrics/regression/ssim.py +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -14,9 +14,9 @@ from typing import Any, Optional, Sequence import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/retrieval/retrieval_metric.py b/pytorch_lightning/metrics/retrieval/retrieval_metric.py index 29f02555dad69..6f9088d00083c 100644 --- a/pytorch_lightning/metrics/retrieval/retrieval_metric.py +++ b/pytorch_lightning/metrics/retrieval/retrieval_metric.py @@ -2,8 +2,8 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric -from pytorch_lightning.metrics import Metric from pytorch_lightning.metrics.utils import get_group_indexes #: get_group_indexes is used to group predictions belonging to the same query diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index e5ec25a9ad556..5cbb8a877948a 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -24,20 +24,20 @@ from torchmetrics.utilities.distributed import class_reduce as __class_reduce from torchmetrics.utilities.distributed import reduce as __reduce -from pytorch_lightning.utilities.deprecation import deprecated_func +from pytorch_lightning.utilities.deprecation import _deprecated -@deprecated_func(target_func=__dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_cat(x): pass -@deprecated_func(target_func=__dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_sum(x): pass -@deprecated_func(target_func=__dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_mean(x): pass @@ -71,7 +71,7 @@ def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()] -@deprecated_func(target_func=__to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: """ .. deprecated:: @@ -79,7 +79,7 @@ def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> """ -@deprecated_func(target_func=__select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ .. deprecated:: @@ -87,7 +87,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch """ -@deprecated_func(target_func=__to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ .. deprecated:: @@ -95,7 +95,7 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ -@deprecated_func(target_func=__get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: """ .. deprecated:: @@ -103,7 +103,7 @@ def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optio """ -@deprecated_func(target_func=__reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ .. deprecated:: @@ -111,7 +111,7 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ -@deprecated_func(target_func=__class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") +@_deprecated(target=__class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 82f328a927485..554f1d3faf9ed 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -15,8 +15,7 @@ from typing import Any import torch - -from pytorch_lightning.metrics.metric import Metric +from torchmetrics import Metric class MetricsHolder: diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py index 336d833e229cd..ba0d0483ee8bb 100644 --- a/pytorch_lightning/utilities/deprecation.py +++ b/pytorch_lightning/utilities/deprecation.py @@ -34,22 +34,31 @@ def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]] return name_type_default -def deprecated_func(target_func: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable: +def _deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable: + """ + Decorate a function or class ``__init__`` with warning message + and pass all arguments directly to the target class/method. + """ def inner_function(func): @wraps(func) def wrapper(*args, **kwargs): - target_func_str = f'{target_func.__module__}.{target_func.__name__}' + is_class = inspect.isclass(target) + target_func = target.__init__ if is_class else target + target_str = f'{target.__module__}.{target.__name__}' + func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__ rank_zero_warn( - f"This `{func.__name__}` was deprecated since v{ver_deprecate} in favor of `{target_func_str}`." + f"This `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." f" It will be removed in v{ver_remove}.", DeprecationWarning ) + if args: # in case any args passed move them to kwargs # parse only the argument names cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)] # convert args to kwargs kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)] assert all(arg in target_args for arg in kwargs), \ "Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs] diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 9d31688d9bcc0..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -15,10 +15,10 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +from torchmetrics import Metric import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result -from pytorch_lightning.metrics import Metric from tests.helpers.runif import RunIf diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index fcbeaa2e78819..3a526ee7fe470 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -16,10 +16,9 @@ import pytest import torch +from pytorch_lightning.metrics import Accuracy, MetricCollection from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot -from pytorch_lightning.metrics import MetricCollection, Accuracy - def test_v1_5_0_metrics_utils(): x = torch.tensor([1, 2, 3]) @@ -42,8 +41,8 @@ def test_v1_5_0_metrics_collection(): target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) with pytest.deprecated_call( - match="This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`." - " It will be removed in v1.5.0" + match="This `MetricCollection` was deprecated since v1.3.0 in favor" + " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0" ): metrics = MetricCollection([Accuracy()]) assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]} diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index f13c1ebe26d3e..c9e5467414832 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,9 +5,10 @@ import pytest import torch from sklearn.metrics import precision_score, recall_score +from torchmetrics import Metric from torchmetrics.classification.checks import _input_format_classification -from pytorch_lightning.metrics import Metric, Precision, Recall +from pytorch_lightning.metrics import Precision, Recall from pytorch_lightning.metrics.functional import precision, precision_recall, recall from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/retrieval/test_map.py b/tests/metrics/retrieval/test_map.py index aa6eeb6424a33..fe43f19b20eb6 100644 --- a/tests/metrics/retrieval/test_map.py +++ b/tests/metrics/retrieval/test_map.py @@ -6,9 +6,9 @@ import pytest import torch from sklearn.metrics import average_precision_score as sk_average_precision +from torchmetrics import Metric from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 895305fa9da7e..93dd213dd7a89 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,7 +1,7 @@ import torch +from torchmetrics import Metric from pytorch_lightning import Trainer -from pytorch_lightning.metrics import Metric, MetricCollection from tests.helpers.boring_model import BoringModel diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4bd6608ce3fcf..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -8,8 +8,7 @@ import pytest import torch from torch.multiprocessing import Pool, set_start_method - -from pytorch_lightning.metrics import Metric +from torchmetrics import Metric try: set_start_method("spawn") diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index 730838686a060..f20bea32993ff 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -1,13 +1,13 @@ import pytest -from pytorch_lightning.utilities.deprecation import deprecated_func +from pytorch_lightning.utilities.deprecation import _deprecated def my_sum(a, b=3): return a + b -@deprecated_func(target_func=my_sum, ver_deprecate="0.1", ver_remove="0.5") +@_deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") def dep_sum(a, b): pass From b5c4d21039b1ad300abe24c6cf3c2deeefc5bf0e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 10:36:21 +0100 Subject: [PATCH 05/11] count --- pytorch_lightning/utilities/deprecation.py | 19 +++++++++++-------- tests/utilities/test_deprecation.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py index ba0d0483ee8bb..db4ef039c3580 100644 --- a/pytorch_lightning/utilities/deprecation.py +++ b/pytorch_lightning/utilities/deprecation.py @@ -43,15 +43,18 @@ def _deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") def inner_function(func): @wraps(func) - def wrapper(*args, **kwargs): + def wrapped_fn(*args, **kwargs): is_class = inspect.isclass(target) target_func = target.__init__ if is_class else target - target_str = f'{target.__module__}.{target.__name__}' - func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__ - rank_zero_warn( - f"This `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." - f" It will be removed in v{ver_remove}.", DeprecationWarning - ) + # warn user only once in lifetime + if not getattr(inner_function, 'warned', False): + target_str = f'{target.__module__}.{target.__name__}' + func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__ + rank_zero_warn( + f"This `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." + f" It will be removed in v{ver_remove}.", DeprecationWarning + ) + inner_function.warned = True if args: # in case any args passed move them to kwargs # parse only the argument names @@ -65,6 +68,6 @@ def wrapper(*args, **kwargs): # all args were already moved to kwargs return target_func(**kwargs) - return wrapper + return wrapped_fn return inner_function diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index f20bea32993ff..ee28a90f9e265 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -12,9 +12,24 @@ def dep_sum(a, b): pass +@_deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") +def dep2_sum(a, b): + pass + + def test_deprecated_func(): with pytest.deprecated_call( match='This `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' ' It will be removed in v0.5.' ): assert dep_sum(2, b=5) == 7 + + with pytest.warns(None) as record: + assert dep_sum(2, b=5) == 7 + assert len(record) == 0 + + with pytest.deprecated_call( + match='This `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + ' It will be removed in v0.5.' + ): + assert dep2_sum(2) == 5 From 8041708094abff60459ea03e60ba91c99ce6ed57 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 10:43:23 +0100 Subject: [PATCH 06/11] flake8 --- tests/deprecated_api/test_remove_1-5_metrics.py | 4 ++-- tests/metrics/test_metric_lightning.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index 3a526ee7fe470..ac91d8acab8a0 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -41,8 +41,8 @@ def test_v1_5_0_metrics_collection(): target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) with pytest.deprecated_call( - match="This `MetricCollection` was deprecated since v1.3.0 in favor" - " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0" + match="This `MetricCollection` was deprecated since v1.3.0 in favor" + " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0" ): metrics = MetricCollection([Accuracy()]) assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]} diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 93dd213dd7a89..2e040a881d49f 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,5 +1,5 @@ import torch -from torchmetrics import Metric +from torchmetrics import Metric, MetricCollection from pytorch_lightning import Trainer from tests.helpers.boring_model import BoringModel From 866826a665a623841aa81c17fdd8db92537513e9 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 12:57:00 +0100 Subject: [PATCH 07/11] Apply suggestions from code review --- tests/utilities/test_deprecation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index ee28a90f9e265..25e481d789194 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -24,10 +24,12 @@ def test_deprecated_func(): ): assert dep_sum(2, b=5) == 7 + # check that the warning is raised only once per function with pytest.warns(None) as record: assert dep_sum(2, b=5) == 7 assert len(record) == 0 + # and does not affect other functions with pytest.deprecated_call( match='This `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' ' It will be removed in v0.5.' From 82f53c4feb41fc1cb7ecfe937eac07abcc15eda4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 13:02:38 +0100 Subject: [PATCH 08/11] rename --- pytorch_lightning/metrics/metric.py | 4 ++-- pytorch_lightning/metrics/utils.py | 20 +++++++++---------- pytorch_lightning/utilities/deprecation.py | 4 ++-- .../deprecated_api/test_remove_1-5_metrics.py | 2 +- tests/utilities/test_deprecation.py | 6 +++--- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index d4d621b158303..918c92049846e 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -16,7 +16,7 @@ from torchmetrics import Metric as _Metric from torchmetrics.collections import MetricCollection as _MetricCollection -from pytorch_lightning.utilities.deprecation import _deprecated +from pytorch_lightning.utilities.deprecation import deprecated from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -51,6 +51,6 @@ class MetricCollection(_MetricCollection): Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0. """ - @_deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): pass diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 5cbb8a877948a..d0e86d50345e5 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -24,20 +24,20 @@ from torchmetrics.utilities.distributed import class_reduce as __class_reduce from torchmetrics.utilities.distributed import reduce as __reduce -from pytorch_lightning.utilities.deprecation import _deprecated +from pytorch_lightning.utilities.deprecation import deprecated -@_deprecated(target=__dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_cat(x): pass -@_deprecated(target=__dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_sum(x): pass -@_deprecated(target=__dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_mean(x): pass @@ -71,7 +71,7 @@ def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()] -@_deprecated(target=__to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: """ .. deprecated:: @@ -79,7 +79,7 @@ def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> """ -@_deprecated(target=__select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ .. deprecated:: @@ -87,7 +87,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch """ -@_deprecated(target=__to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ .. deprecated:: @@ -95,7 +95,7 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ -@_deprecated(target=__get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: """ .. deprecated:: @@ -103,7 +103,7 @@ def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optio """ -@_deprecated(target=__reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ .. deprecated:: @@ -111,7 +111,7 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ -@_deprecated(target=__class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=__class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py index db4ef039c3580..3e2034c6a0453 100644 --- a/pytorch_lightning/utilities/deprecation.py +++ b/pytorch_lightning/utilities/deprecation.py @@ -34,7 +34,7 @@ def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]] return name_type_default -def _deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable: +def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable: """ Decorate a function or class ``__init__`` with warning message and pass all arguments directly to the target class/method. @@ -51,7 +51,7 @@ def wrapped_fn(*args, **kwargs): target_str = f'{target.__module__}.{target.__name__}' func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__ rank_zero_warn( - f"This `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." + f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." f" It will be removed in v{ver_remove}.", DeprecationWarning ) inner_function.warned = True diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index ac91d8acab8a0..7c8c9ad296416 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -41,7 +41,7 @@ def test_v1_5_0_metrics_collection(): target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) with pytest.deprecated_call( - match="This `MetricCollection` was deprecated since v1.3.0 in favor" + match="The `MetricCollection` was deprecated since v1.3.0 in favor" " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0" ): metrics = MetricCollection([Accuracy()]) diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index 25e481d789194..03d07d4147b61 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -1,18 +1,18 @@ import pytest -from pytorch_lightning.utilities.deprecation import _deprecated +from pytorch_lightning.utilities.deprecation import deprecated def my_sum(a, b=3): return a + b -@_deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") +@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") def dep_sum(a, b): pass -@_deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") +@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") def dep2_sum(a, b): pass From 1b2b95c2a005d5e1def665c12da328a076be9588 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 13:04:58 +0100 Subject: [PATCH 09/11] call --- tests/utilities/test_deprecation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index 03d07d4147b61..ef8456491b78b 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -1,6 +1,7 @@ import pytest from pytorch_lightning.utilities.deprecation import deprecated +from tests.helpers.utils import no_warning_call def my_sum(a, b=3): @@ -25,9 +26,8 @@ def test_deprecated_func(): assert dep_sum(2, b=5) == 7 # check that the warning is raised only once per function - with pytest.warns(None) as record: + with no_warning_call(DeprecationWarning): assert dep_sum(2, b=5) == 7 - assert len(record) == 0 # and does not affect other functions with pytest.deprecated_call( From 833aef7b464e2ed44ca11ca7961ba57cd316d435 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 13:21:25 +0100 Subject: [PATCH 10/11] typo --- tests/utilities/test_deprecation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index ef8456491b78b..7c653c07ad168 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -20,7 +20,7 @@ def dep2_sum(a, b): def test_deprecated_func(): with pytest.deprecated_call( - match='This `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + match='The `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' ' It will be removed in v0.5.' ): assert dep_sum(2, b=5) == 7 @@ -31,7 +31,7 @@ def test_deprecated_func(): # and does not affect other functions with pytest.deprecated_call( - match='This `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' ' It will be removed in v0.5.' ): assert dep2_sum(2) == 5 From 598a34b348750e7488902f5d0e0b937bb5d7140b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 15:17:39 +0100 Subject: [PATCH 11/11] __ --- pytorch_lightning/metrics/utils.py | 36 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index d0e86d50345e5..b758e317c6c8d 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -14,30 +14,30 @@ from typing import List, Optional import torch -from torchmetrics.utilities.data import dim_zero_cat as __dim_zero_cat -from torchmetrics.utilities.data import dim_zero_mean as __dim_zero_mean -from torchmetrics.utilities.data import dim_zero_sum as __dim_zero_sum -from torchmetrics.utilities.data import get_num_classes as __get_num_classes -from torchmetrics.utilities.data import select_topk as __select_topk -from torchmetrics.utilities.data import to_categorical as __to_categorical -from torchmetrics.utilities.data import to_onehot as __to_onehot -from torchmetrics.utilities.distributed import class_reduce as __class_reduce -from torchmetrics.utilities.distributed import reduce as __reduce +from torchmetrics.utilities.data import dim_zero_cat as _dim_zero_cat +from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean +from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum +from torchmetrics.utilities.data import get_num_classes as _get_num_classes +from torchmetrics.utilities.data import select_topk as _select_topk +from torchmetrics.utilities.data import to_categorical as _to_categorical +from torchmetrics.utilities.data import to_onehot as _to_onehot +from torchmetrics.utilities.distributed import class_reduce as _class_reduce +from torchmetrics.utilities.distributed import reduce as _reduce from pytorch_lightning.utilities.deprecation import deprecated -@deprecated(target=__dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_cat(x): pass -@deprecated(target=__dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_sum(x): pass -@deprecated(target=__dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_mean(x): pass @@ -71,7 +71,7 @@ def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()] -@deprecated(target=__to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: """ .. deprecated:: @@ -79,7 +79,7 @@ def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> """ -@deprecated(target=__select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ .. deprecated:: @@ -87,7 +87,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch """ -@deprecated(target=__to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ .. deprecated:: @@ -95,7 +95,7 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ -@deprecated(target=__get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: """ .. deprecated:: @@ -103,7 +103,7 @@ def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optio """ -@deprecated(target=__reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ .. deprecated:: @@ -111,7 +111,7 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ -@deprecated(target=__class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated(target=_class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: