Skip to content

Commit 297e438

Browse files
authored
fix deprecation wrapper & tests (#6553)
* fix deprecation wrapper & tests * flake8
1 parent 00cd918 commit 297e438

File tree

3 files changed

+85
-24
lines changed

3 files changed

+85
-24
lines changed

pytorch_lightning/utilities/deprecation.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import inspect
1515
from functools import wraps
16-
from typing import Any, Callable, List, Tuple
16+
from typing import Any, Callable, List, Tuple, Optional
1717

1818
from pytorch_lightning.utilities import rank_zero_warn
1919

@@ -34,37 +34,41 @@ def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]]
3434
return name_type_default
3535

3636

37-
def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable:
37+
def deprecated(target: Callable, ver_deprecate: Optional[str] = "", ver_remove: Optional[str] = "") -> Callable:
3838
"""
3939
Decorate a function or class ``__init__`` with warning message
4040
and pass all arguments directly to the target class/method.
41-
"""
41+
"""
4242

43-
def inner_function(func):
43+
def inner_function(base):
4444

45-
@wraps(func)
45+
@wraps(base)
4646
def wrapped_fn(*args, **kwargs):
4747
is_class = inspect.isclass(target)
4848
target_func = target.__init__ if is_class else target
4949
# warn user only once in lifetime
50-
if not getattr(inner_function, 'warned', False):
50+
if not getattr(wrapped_fn, 'warned', False):
5151
target_str = f'{target.__module__}.{target.__name__}'
52-
func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__
52+
base_name = base.__qualname__.split('.')[-2] if is_class else base.__name__
53+
base_str = f'{base.__module__}.{base_name}'
5354
rank_zero_warn(
54-
f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
55+
f"`{base_str}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
5556
f" It will be removed in v{ver_remove}.", DeprecationWarning
5657
)
57-
inner_function.warned = True
58+
wrapped_fn.warned = True
5859

5960
if args: # in case any args passed move them to kwargs
6061
# parse only the argument names
61-
cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)]
62+
arg_names = [arg[0] for arg in get_func_arguments_and_types(base)]
6263
# convert args to kwargs
63-
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
64+
kwargs.update({k: v for k, v in zip(arg_names, args)})
65+
# fill by base defaults
66+
base_defaults = {arg[0]: arg[2] for arg in get_func_arguments_and_types(base) if arg[2] != inspect._empty}
67+
kwargs = dict(list(base_defaults.items()) + list(kwargs.items()))
6468

6569
target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)]
6670
assert all(arg in target_args for arg in kwargs), \
67-
"Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs]
71+
"Failed mapping, arguments missing in target base: %s" % [arg not in target_args for arg in kwargs]
6872
# all args were already moved to kwargs
6973
return target_func(**kwargs)
7074

tests/deprecated_api/test_remove_1-5_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def test_v1_5_0_metrics_collection():
4141
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
4242
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
4343
with pytest.deprecated_call(
44-
match="The `MetricCollection` was deprecated since v1.3.0 in favor"
45-
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0"
44+
match="`pytorch_lightning.metrics.metric.MetricCollection` was deprecated since v1.3.0 in favor"
45+
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0."
4646
):
4747
metrics = MetricCollection([Accuracy()])
4848
assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]}

tests/utilities/test_deprecation.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,91 @@
44
from tests.helpers.utils import no_warning_call
55

66

7-
def my_sum(a, b=3):
7+
def my_sum(a=0, b=3):
8+
return a + b
9+
10+
11+
def my2_sum(a, b):
812
return a + b
913

1014

1115
@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
12-
def dep_sum(a, b):
16+
def dep_sum(a, b=5):
1317
pass
1418

1519

16-
@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
20+
@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5")
1721
def dep2_sum(a, b):
1822
pass
1923

2024

25+
@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5")
26+
def dep3_sum(a, b=4):
27+
pass
28+
29+
2130
def test_deprecated_func():
2231
with pytest.deprecated_call(
23-
match='The `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.'
24-
' It will be removed in v0.5.'
32+
match='`tests.utilities.test_deprecation.dep_sum` was deprecated since v0.1 in favor'
33+
' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.'
2534
):
26-
assert dep_sum(2, b=5) == 7
35+
assert dep_sum(2) == 7
2736

2837
# check that the warning is raised only once per function
2938
with no_warning_call(DeprecationWarning):
30-
assert dep_sum(2, b=5) == 7
39+
assert dep_sum(3) == 8
3140

3241
# and does not affect other functions
3342
with pytest.deprecated_call(
34-
match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.'
35-
' It will be removed in v0.5.'
43+
match='`tests.utilities.test_deprecation.dep3_sum` was deprecated since v0.1 in favor'
44+
' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
3645
):
37-
assert dep2_sum(2) == 5
46+
assert dep3_sum(2, 1) == 3
47+
48+
49+
def test_deprecated_func_incomplete():
50+
51+
# missing required argument
52+
with pytest.raises(TypeError, match="missing 1 required positional argument: 'b'"):
53+
dep2_sum(2)
54+
55+
# check that the warning is raised only once per function
56+
with no_warning_call(DeprecationWarning):
57+
assert dep2_sum(2, 1) == 3
58+
59+
# reset the warning
60+
dep2_sum.warned = False
61+
# does not affect other functions
62+
with pytest.deprecated_call(
63+
match='`tests.utilities.test_deprecation.dep2_sum` was deprecated since v0.1 in favor'
64+
' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
65+
):
66+
assert dep2_sum(b=2, a=1) == 3
67+
68+
69+
class NewCls:
70+
71+
def __init__(self, c, d="abc"):
72+
self.my_c = c
73+
self.my_d = d
74+
75+
76+
class PastCls:
77+
78+
@deprecated(target=NewCls, ver_deprecate="0.2", ver_remove="0.4")
79+
def __init__(self, c, d="efg"):
80+
pass
81+
82+
83+
def test_deprecated_class():
84+
with pytest.deprecated_call(
85+
match='`tests.utilities.test_deprecation.PastCls` was deprecated since v0.2 in favor'
86+
' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.'
87+
):
88+
past = PastCls(2)
89+
assert past.my_c == 2
90+
assert past.my_d == "efg"
91+
92+
# check that the warning is raised only once per function
93+
with no_warning_call(DeprecationWarning):
94+
assert PastCls(c=2, d="")

0 commit comments

Comments
 (0)