Skip to content

Commit b50ad9e

Browse files
Bordarohitgr7
andauthored
split tests for deprecated api (#5071)
* imports * imports * flake8 Co-authored-by: Rohit Gupta <[email protected]>
1 parent 3100b78 commit b50ad9e

File tree

3 files changed

+78
-83
lines changed

3 files changed

+78
-83
lines changed

tests/deprecated_api/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Test deprecated functionality which will be removed in vX.Y.Z"""
15+
import sys
16+
17+
18+
def _soft_unimport_module(str_module):
19+
# once the module is imported e.g with parsing with pytest it lives in memory
20+
if str_module in sys.modules:
21+
del sys.modules[str_module]
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Test deprecated functionality which will be removed in vX.Y.Z"""
15+
16+
import pytest
17+
import torch
18+
19+
from pytorch_lightning.callbacks import ModelCheckpoint
20+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
21+
22+
23+
def test_tbd_remove_in_v1_2_0():
24+
with pytest.deprecated_call(match='will be removed in v1.2'):
25+
ModelCheckpoint(filepath='..')
26+
27+
with pytest.deprecated_call(match='will be removed in v1.2'):
28+
ModelCheckpoint('..')
29+
30+
with pytest.raises(MisconfigurationException, match='inputs which are not feasible'):
31+
ModelCheckpoint(filepath='..', dirpath='.')
32+
33+
34+
def test_tbd_remove_in_v1_2_0_metrics():
35+
from pytorch_lightning.metrics.classification import Fbeta
36+
from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score
37+
38+
with pytest.deprecated_call(match='will be removed in v1.2'):
39+
Fbeta(2)
40+
41+
with pytest.deprecated_call(match='will be removed in v1.2'):
42+
fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2)
43+
44+
with pytest.deprecated_call(match='will be removed in v1.2'):
45+
f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0]))

tests/test_deprecated.py renamed to tests/deprecated_api/test_remove_1-3.py

Lines changed: 12 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Test deprecated functionality which will be removed in vX.Y.Z"""
15-
import sys
1615
from argparse import ArgumentParser
1716
from unittest import mock
1817

@@ -21,10 +20,8 @@
2120

2221
from pytorch_lightning import LightningModule, Trainer
2322
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
24-
from pytorch_lightning.metrics.functional.classification import auc
2523
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
2624
from pytorch_lightning.utilities.exceptions import MisconfigurationException
27-
from tests.base import EvalModelTemplate
2825

2926

3027
def test_tbd_remove_in_v1_3_0(tmpdir):
@@ -52,27 +49,27 @@ def __init__(self, hparams):
5249

5350

5451
def test_tbd_remove_in_v1_3_0_metrics():
52+
from pytorch_lightning.metrics.functional.classification import to_onehot
5553
with pytest.deprecated_call(match='will be removed in v1.3'):
56-
from pytorch_lightning.metrics.functional.classification import to_onehot
5754
to_onehot(torch.tensor([1, 2, 3]))
5855

56+
from pytorch_lightning.metrics.functional.classification import to_categorical
5957
with pytest.deprecated_call(match='will be removed in v1.3'):
60-
from pytorch_lightning.metrics.functional.classification import to_categorical
6158
to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]]))
6259

60+
from pytorch_lightning.metrics.functional.classification import get_num_classes
6361
with pytest.deprecated_call(match='will be removed in v1.3'):
64-
from pytorch_lightning.metrics.functional.classification import get_num_classes
6562
get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1]))
6663

6764
x_binary = torch.tensor([0, 1, 2, 3])
6865
y_binary = torch.tensor([0, 1, 2, 3])
6966

67+
from pytorch_lightning.metrics.functional.classification import roc
7068
with pytest.deprecated_call(match='will be removed in v1.3'):
71-
from pytorch_lightning.metrics.functional.classification import roc
7269
roc(pred=x_binary, target=y_binary)
7370

71+
from pytorch_lightning.metrics.functional.classification import _roc
7472
with pytest.deprecated_call(match='will be removed in v1.3'):
75-
from pytorch_lightning.metrics.functional.classification import _roc
7673
_roc(pred=x_binary, target=y_binary)
7774

7875
x_multy = torch.tensor([[0.85, 0.05, 0.05, 0.05],
@@ -81,64 +78,40 @@ def test_tbd_remove_in_v1_3_0_metrics():
8178
[0.05, 0.05, 0.05, 0.85]])
8279
y_multy = torch.tensor([0, 1, 3, 2])
8380

81+
from pytorch_lightning.metrics.functional.classification import multiclass_roc
8482
with pytest.deprecated_call(match='will be removed in v1.3'):
85-
from pytorch_lightning.metrics.functional.classification import multiclass_roc
8683
multiclass_roc(pred=x_multy, target=y_multy)
8784

85+
from pytorch_lightning.metrics.functional.classification import average_precision
8886
with pytest.deprecated_call(match='will be removed in v1.3'):
89-
from pytorch_lightning.metrics.functional.classification import average_precision
9087
average_precision(pred=x_binary, target=y_binary)
9188

89+
from pytorch_lightning.metrics.functional.classification import precision_recall_curve
9290
with pytest.deprecated_call(match='will be removed in v1.3'):
93-
from pytorch_lightning.metrics.functional.classification import precision_recall_curve
9491
precision_recall_curve(pred=x_binary, target=y_binary)
9592

93+
from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve
9694
with pytest.deprecated_call(match='will be removed in v1.3'):
97-
from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve
9895
multiclass_precision_recall_curve(pred=x_multy, target=y_multy)
9996

97+
from pytorch_lightning.metrics.functional.reduction import reduce
10098
with pytest.deprecated_call(match='will be removed in v1.3'):
101-
from pytorch_lightning.metrics.functional.reduction import reduce
10299
reduce(torch.tensor([0, 1, 1, 0]), 'sum')
103100

101+
from pytorch_lightning.metrics.functional.reduction import class_reduce
104102
with pytest.deprecated_call(match='will be removed in v1.3'):
105-
from pytorch_lightning.metrics.functional.reduction import class_reduce
106103
class_reduce(torch.randint(1, 10, (50,)).float(),
107104
torch.randint(10, 20, (50,)).float(),
108105
torch.randint(1, 100, (50,)).float())
109106

110107

111-
def test_tbd_remove_in_v1_2_0():
112-
with pytest.deprecated_call(match='will be removed in v1.2'):
113-
checkpoint_cb = ModelCheckpoint(filepath='.')
114-
115-
with pytest.deprecated_call(match='will be removed in v1.2'):
116-
checkpoint_cb = ModelCheckpoint('.')
117-
118-
with pytest.raises(MisconfigurationException, match='inputs which are not feasible'):
119-
checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.')
120-
121-
122-
def test_tbd_remove_in_v1_2_0_metrics():
123-
from pytorch_lightning.metrics.classification import Fbeta
124-
from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score
125-
126-
with pytest.deprecated_call(match='will be removed in v1.2'):
127-
Fbeta(2)
128-
129-
with pytest.deprecated_call(match='will be removed in v1.2'):
130-
fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2)
131-
132-
with pytest.deprecated_call(match='will be removed in v1.2'):
133-
f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0]))
134-
135-
136108
# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
137109
@pytest.mark.parametrize(['profiler', 'expected'], [
138110
(True, SimpleProfiler),
139111
(False, PassThroughProfiler),
140112
])
141113
def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
114+
# remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
142115
with pytest.deprecated_call(match='will be removed in v1.3'):
143116
trainer = Trainer(profiler=profiler)
144117
assert isinstance(trainer.profiler, expected)
@@ -162,47 +135,3 @@ def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, ex
162135
assert getattr(args, "profiler") == expected_parsed_arg
163136
trainer = Trainer.from_argparse_args(args)
164137
assert isinstance(trainer.profiler, expected_profiler)
165-
166-
167-
def _soft_unimport_module(str_module):
168-
# once the module is imported e.g with parsing with pytest it lives in memory
169-
if str_module in sys.modules:
170-
del sys.modules[str_module]
171-
172-
173-
class ModelVer0_6(EvalModelTemplate):
174-
175-
# todo: this shall not be needed while evaluate asks for dataloader explicitly
176-
def val_dataloader(self):
177-
return self.dataloader(train=False)
178-
179-
def validation_step(self, batch, batch_idx, *args, **kwargs):
180-
return {'val_loss': torch.tensor(0.6)}
181-
182-
def validation_end(self, outputs):
183-
return {'val_loss': torch.tensor(0.6)}
184-
185-
def test_dataloader(self):
186-
return self.dataloader(train=False)
187-
188-
def test_end(self, outputs):
189-
return {'test_loss': torch.tensor(0.6)}
190-
191-
192-
class ModelVer0_7(EvalModelTemplate):
193-
194-
# todo: this shall not be needed while evaluate asks for dataloader explicitly
195-
def val_dataloader(self):
196-
return self.dataloader(train=False)
197-
198-
def validation_step(self, batch, batch_idx, *args, **kwargs):
199-
return {'val_loss': torch.tensor(0.7)}
200-
201-
def validation_end(self, outputs):
202-
return {'val_loss': torch.tensor(0.7)}
203-
204-
def test_dataloader(self):
205-
return self.dataloader(train=False)
206-
207-
def test_end(self, outputs):
208-
return {'test_loss': torch.tensor(0.7)}

0 commit comments

Comments
 (0)