1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414"""Test deprecated functionality which will be removed in vX.Y.Z"""
15- import sys
1615from argparse import ArgumentParser
1716from unittest import mock
1817
2120
2221from pytorch_lightning import LightningModule , Trainer
2322from pytorch_lightning .callbacks import EarlyStopping , ModelCheckpoint
24- from pytorch_lightning .metrics .functional .classification import auc
2523from pytorch_lightning .profiler .profilers import PassThroughProfiler , SimpleProfiler
2624from pytorch_lightning .utilities .exceptions import MisconfigurationException
27- from tests .base import EvalModelTemplate
2825
2926
3027def test_tbd_remove_in_v1_3_0 (tmpdir ):
@@ -52,27 +49,27 @@ def __init__(self, hparams):
5249
5350
5451def test_tbd_remove_in_v1_3_0_metrics ():
52+ from pytorch_lightning .metrics .functional .classification import to_onehot
5553 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
56- from pytorch_lightning .metrics .functional .classification import to_onehot
5754 to_onehot (torch .tensor ([1 , 2 , 3 ]))
5855
56+ from pytorch_lightning .metrics .functional .classification import to_categorical
5957 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
60- from pytorch_lightning .metrics .functional .classification import to_categorical
6158 to_categorical (torch .tensor ([[0.2 , 0.5 ], [0.9 , 0.1 ]]))
6259
60+ from pytorch_lightning .metrics .functional .classification import get_num_classes
6361 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
64- from pytorch_lightning .metrics .functional .classification import get_num_classes
6562 get_num_classes (pred = torch .tensor ([0 , 1 ]), target = torch .tensor ([1 , 1 ]))
6663
6764 x_binary = torch .tensor ([0 , 1 , 2 , 3 ])
6865 y_binary = torch .tensor ([0 , 1 , 2 , 3 ])
6966
67+ from pytorch_lightning .metrics .functional .classification import roc
7068 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
71- from pytorch_lightning .metrics .functional .classification import roc
7269 roc (pred = x_binary , target = y_binary )
7370
71+ from pytorch_lightning .metrics .functional .classification import _roc
7472 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
75- from pytorch_lightning .metrics .functional .classification import _roc
7673 _roc (pred = x_binary , target = y_binary )
7774
7875 x_multy = torch .tensor ([[0.85 , 0.05 , 0.05 , 0.05 ],
@@ -81,64 +78,40 @@ def test_tbd_remove_in_v1_3_0_metrics():
8178 [0.05 , 0.05 , 0.05 , 0.85 ]])
8279 y_multy = torch .tensor ([0 , 1 , 3 , 2 ])
8380
81+ from pytorch_lightning .metrics .functional .classification import multiclass_roc
8482 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
85- from pytorch_lightning .metrics .functional .classification import multiclass_roc
8683 multiclass_roc (pred = x_multy , target = y_multy )
8784
85+ from pytorch_lightning .metrics .functional .classification import average_precision
8886 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
89- from pytorch_lightning .metrics .functional .classification import average_precision
9087 average_precision (pred = x_binary , target = y_binary )
9188
89+ from pytorch_lightning .metrics .functional .classification import precision_recall_curve
9290 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
93- from pytorch_lightning .metrics .functional .classification import precision_recall_curve
9491 precision_recall_curve (pred = x_binary , target = y_binary )
9592
93+ from pytorch_lightning .metrics .functional .classification import multiclass_precision_recall_curve
9694 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
97- from pytorch_lightning .metrics .functional .classification import multiclass_precision_recall_curve
9895 multiclass_precision_recall_curve (pred = x_multy , target = y_multy )
9996
97+ from pytorch_lightning .metrics .functional .reduction import reduce
10098 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
101- from pytorch_lightning .metrics .functional .reduction import reduce
10299 reduce (torch .tensor ([0 , 1 , 1 , 0 ]), 'sum' )
103100
101+ from pytorch_lightning .metrics .functional .reduction import class_reduce
104102 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
105- from pytorch_lightning .metrics .functional .reduction import class_reduce
106103 class_reduce (torch .randint (1 , 10 , (50 ,)).float (),
107104 torch .randint (10 , 20 , (50 ,)).float (),
108105 torch .randint (1 , 100 , (50 ,)).float ())
109106
110107
111- def test_tbd_remove_in_v1_2_0 ():
112- with pytest .deprecated_call (match = 'will be removed in v1.2' ):
113- checkpoint_cb = ModelCheckpoint (filepath = '.' )
114-
115- with pytest .deprecated_call (match = 'will be removed in v1.2' ):
116- checkpoint_cb = ModelCheckpoint ('.' )
117-
118- with pytest .raises (MisconfigurationException , match = 'inputs which are not feasible' ):
119- checkpoint_cb = ModelCheckpoint (filepath = '.' , dirpath = '.' )
120-
121-
122- def test_tbd_remove_in_v1_2_0_metrics ():
123- from pytorch_lightning .metrics .classification import Fbeta
124- from pytorch_lightning .metrics .functional .classification import f1_score , fbeta_score
125-
126- with pytest .deprecated_call (match = 'will be removed in v1.2' ):
127- Fbeta (2 )
128-
129- with pytest .deprecated_call (match = 'will be removed in v1.2' ):
130- fbeta_score (torch .tensor ([0 , 1 , 2 , 3 ]), torch .tensor ([0 , 1 , 2 , 1 ]), 0.2 )
131-
132- with pytest .deprecated_call (match = 'will be removed in v1.2' ):
133- f1_score (torch .tensor ([0 , 1 , 0 , 1 ]), torch .tensor ([0 , 1 , 0 , 0 ]))
134-
135-
136108# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
137109@pytest .mark .parametrize (['profiler' , 'expected' ], [
138110 (True , SimpleProfiler ),
139111 (False , PassThroughProfiler ),
140112])
141113def test_trainer_profiler_remove_in_v1_3_0 (profiler , expected ):
114+ # remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
142115 with pytest .deprecated_call (match = 'will be removed in v1.3' ):
143116 trainer = Trainer (profiler = profiler )
144117 assert isinstance (trainer .profiler , expected )
@@ -162,47 +135,3 @@ def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, ex
162135 assert getattr (args , "profiler" ) == expected_parsed_arg
163136 trainer = Trainer .from_argparse_args (args )
164137 assert isinstance (trainer .profiler , expected_profiler )
165-
166-
167- def _soft_unimport_module (str_module ):
168- # once the module is imported e.g with parsing with pytest it lives in memory
169- if str_module in sys .modules :
170- del sys .modules [str_module ]
171-
172-
173- class ModelVer0_6 (EvalModelTemplate ):
174-
175- # todo: this shall not be needed while evaluate asks for dataloader explicitly
176- def val_dataloader (self ):
177- return self .dataloader (train = False )
178-
179- def validation_step (self , batch , batch_idx , * args , ** kwargs ):
180- return {'val_loss' : torch .tensor (0.6 )}
181-
182- def validation_end (self , outputs ):
183- return {'val_loss' : torch .tensor (0.6 )}
184-
185- def test_dataloader (self ):
186- return self .dataloader (train = False )
187-
188- def test_end (self , outputs ):
189- return {'test_loss' : torch .tensor (0.6 )}
190-
191-
192- class ModelVer0_7 (EvalModelTemplate ):
193-
194- # todo: this shall not be needed while evaluate asks for dataloader explicitly
195- def val_dataloader (self ):
196- return self .dataloader (train = False )
197-
198- def validation_step (self , batch , batch_idx , * args , ** kwargs ):
199- return {'val_loss' : torch .tensor (0.7 )}
200-
201- def validation_end (self , outputs ):
202- return {'val_loss' : torch .tensor (0.7 )}
203-
204- def test_dataloader (self ):
205- return self .dataloader (train = False )
206-
207- def test_end (self , outputs ):
208- return {'test_loss' : torch .tensor (0.7 )}
0 commit comments