|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | """Test deprecated functionality which will be removed in vX.Y.Z""" |
15 | | -from argparse import ArgumentParser |
16 | | -from unittest import mock |
17 | 15 |
|
18 | 16 | import pytest |
19 | 17 | import torch |
20 | 18 |
|
21 | 19 | from pytorch_lightning import LightningModule, Trainer |
22 | 20 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint |
23 | | -from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler |
24 | 21 |
|
25 | 22 |
|
26 | 23 | def test_v1_3_0_deprecated_arguments(tmpdir): |
@@ -111,38 +108,6 @@ def test_v1_3_0_deprecated_metrics(): |
111 | 108 | ) |
112 | 109 |
|
113 | 110 |
|
114 | | -# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py |
115 | | -@pytest.mark.parametrize(['profiler', 'expected'], [ |
116 | | - (True, SimpleProfiler), |
117 | | - (False, PassThroughProfiler), |
118 | | -]) |
119 | | -def test_trainer_profiler_remove_in_v1_3_0(profiler, expected): |
120 | | - # remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py |
121 | | - with pytest.deprecated_call(match='will be removed in v1.3'): |
122 | | - trainer = Trainer(profiler=profiler) |
123 | | - assert isinstance(trainer.profiler, expected) |
124 | | - |
125 | | - |
126 | | -@pytest.mark.parametrize( |
127 | | - ['cli_args', 'expected_parsed_arg', 'expected_profiler'], |
128 | | - [ |
129 | | - ('--profiler', True, SimpleProfiler), |
130 | | - ('--profiler True', True, SimpleProfiler), |
131 | | - ('--profiler False', False, PassThroughProfiler), |
132 | | - ], |
133 | | -) |
134 | | -def test_v1_3_0_trainer_cli_profiler(cli_args, expected_parsed_arg, expected_profiler): |
135 | | - cli_args = cli_args.split(' ') |
136 | | - with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): |
137 | | - parser = ArgumentParser(add_help=False) |
138 | | - parser = Trainer.add_argparse_args(parent_parser=parser) |
139 | | - args = Trainer.parse_argparser(parser) |
140 | | - |
141 | | - assert getattr(args, "profiler") == expected_parsed_arg |
142 | | - trainer = Trainer.from_argparse_args(args) |
143 | | - assert isinstance(trainer.profiler, expected_profiler) |
144 | | - |
145 | | - |
146 | 111 | def test_trainer_enable_pl_optimizer(tmpdir): |
147 | 112 | with pytest.deprecated_call(match='will be removed in v1.3'): |
148 | 113 | Trainer(enable_pl_optimizer=True) |
0 commit comments