Skip to content

Commit c50c225

Browse files
Authored by: ddrevicky, carmocca, rohitgr7, Borda, awaelchli
feature: Allow str arguments in Trainer.profiler (#3656)
* allow trainer's profiler param to have a str value * add tests * update docs * update exception message * Update CHANGELOG * fix pep8 issues * cleanup test code Co-authored-by: Carlos Mocholí <[email protected]> * Add deprecation warning if using bool for profiler * Add deprecation tests and move deprecated tests * Remove bool option to profiler from docs * Deprecate bool args to profiler in CHANGELOG * fixup! Add deprecation warning if using bool for profiler * fixup! Add deprecation tests and move deprecated tests * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * Implement suggestions, remove whitespace * fixup! Implement suggestions, remove whitespace * Allow bool, str (case insensitive), BaseProfiler * Add info about bool deprecation to trainer * fixup! Add info about bool deprecation to trainer * Move deprecate todo to test_deprecated * Test wrong profiler type, improve error message * fixup! Test wrong profiler type, improve error message * Update pytorch_lightning/trainer/connectors/profiler_connector.py Co-authored-by: Carlos Mocholí <[email protected]> * Apply suggestions from code review * Readd bool to profiler types, test cli profiler arg * Remove extra whitespace in doc Co-authored-by: Adrian Wälchli <[email protected]> * Apply suggestions from code review Co-authored-by: Adrian Wälchli <[email protected]> * Update deprecation versions Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 48b6de0 commit c50c225

File tree

8 files changed

+115
-19
lines changed

8 files changed

+115
-19
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424
- Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344))
2525

2626

27+
- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
28+
29+
2730
### Changed
2831

2932

@@ -48,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4851
- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))
4952

5053

54+
- Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
55+
56+
5157
### Removed
5258

5359

pytorch_lightning/profiler/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
Enable simple profiling
2323
-----------------------
2424
25-
If you only wish to profile the standard actions, you can set `profiler=True` when constructing
26-
your `Trainer` object.
25+
If you only wish to profile the standard actions, you can set `profiler="simple"`
26+
when constructing your `Trainer` object.
2727
2828
.. code-block:: python
2929
30-
trainer = Trainer(..., profiler=True)
30+
trainer = Trainer(..., profiler="simple")
3131
3232
The profiler's results will be printed at the completion of a training `fit()`.
3333
@@ -59,6 +59,10 @@
5959
6060
.. code-block:: python
6161
62+
trainer = Trainer(..., profiler="advanced")
63+
64+
or
65+
6266
profiler = AdvancedProfiler()
6367
trainer = Trainer(..., profiler=profiler)
6468

pytorch_lightning/trainer/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,14 +1199,11 @@ def world_size(self):
11991199
# default used by the Trainer
12001200
trainer = Trainer(profiler=None)
12011201
1202-
# to profile standard training events
1203-
trainer = Trainer(profiler=True)
1202+
# to profile standard training events, equivalent to `profiler=SimpleProfiler()`
1203+
trainer = Trainer(profiler="simple")
12041204
1205-
# equivalent to profiler=True
1206-
trainer = Trainer(profiler=SimpleProfiler())
1207-
1208-
# advanced profiler for function-level stats
1209-
trainer = Trainer(profiler=AdvancedProfiler())
1205+
# advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
1206+
trainer = Trainer(profiler="advanced")
12101207
12111208
progress_bar_refresh_rate
12121209
^^^^^^^^^^^^^^^^^^^^^^^^^

pytorch_lightning/trainer/connectors/profiler_connector.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,40 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License
14-
from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler
14+
15+
from typing import Union
16+
17+
from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler, AdvancedProfiler
18+
from pytorch_lightning.utilities import rank_zero_warn
19+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1520

1621

1722
class ProfilerConnector:
1823

1924
def __init__(self, trainer):
2025
self.trainer = trainer
2126

22-
def on_trainer_init(self, profiler):
23-
# configure profiler
24-
if profiler is True:
25-
profiler = SimpleProfiler()
27+
def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
28+
29+
if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
30+
# TODO: Update exception on removal of bool
31+
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` "
32+
"are valid values for `Trainer`'s `profiler` parameter. "
33+
f"Received {profiler} which is of type {type(profiler)}.")
34+
35+
if isinstance(profiler, bool):
36+
rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
37+
" and will be removed in v1.3. Use str ('simple' or 'advanced') instead.",
38+
DeprecationWarning)
39+
if profiler:
40+
profiler = SimpleProfiler()
41+
elif isinstance(profiler, str):
42+
profiler = profiler.lower()
43+
if profiler == "simple":
44+
profiler = SimpleProfiler()
45+
elif profiler == "advanced":
46+
profiler = AdvancedProfiler()
47+
else:
48+
raise ValueError("When passing string value for the `profiler` parameter of"
49+
" `Trainer`, it can only be 'simple' or 'advanced'")
2650
self.trainer.profiler = profiler or PassThroughProfiler()

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __init__(
120120
num_sanity_val_steps: int = 2,
121121
truncated_bptt_steps: Optional[int] = None,
122122
resume_from_checkpoint: Optional[str] = None,
123-
profiler: Optional[Union[BaseProfiler, bool]] = None,
123+
profiler: Optional[Union[BaseProfiler, bool, str]] = None,
124124
benchmark: bool = False,
125125
deterministic: bool = False,
126126
reload_dataloaders_every_epoch: bool = False,
@@ -212,7 +212,8 @@ def __init__(
212212
progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
213213
Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.
214214
215-
profiler: To profile individual steps during training and assist in identifying bottlenecks.
215+
profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
216+
value is deprecated in v1.1 and will be removed in v1.3.
216217
217218
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
218219

pytorch_lightning/utilities/argparse_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
174174
# if the only arg type is bool
175175
if len(arg_types) == 1:
176176
use_type = parsing.str_to_bool
177-
# if only two args (str, bool)
178-
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
177+
elif str in arg_types:
179178
use_type = parsing.str_to_bool_or_str
180179
else:
181180
# filter out the bool as we need to use more general

tests/test_deprecated.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
"""Test deprecated functionality which will be removed in vX.Y.Z"""
2+
from argparse import ArgumentParser
23
import pytest
34
import sys
5+
from unittest import mock
46

57
import torch
68

79
from tests.base import EvalModelTemplate
810
from pytorch_lightning.metrics.functional.classification import auc
911

12+
from pytorch_lightning import Trainer
1013
from pytorch_lightning.callbacks import ModelCheckpoint
14+
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
1115
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1216

1317

@@ -22,6 +26,37 @@ def test_tbd_remove_in_v1_2_0():
2226
checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.')
2327

2428

29+
# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
30+
@pytest.mark.parametrize(['profiler', 'expected'], [
31+
(True, SimpleProfiler),
32+
(False, PassThroughProfiler),
33+
])
34+
def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
35+
with pytest.deprecated_call(match='will be removed in v1.3'):
36+
trainer = Trainer(profiler=profiler)
37+
assert isinstance(trainer.profiler, expected)
38+
39+
40+
@pytest.mark.parametrize(
41+
['cli_args', 'expected_parsed_arg', 'expected_profiler'],
42+
[
43+
('--profiler', True, SimpleProfiler),
44+
('--profiler True', True, SimpleProfiler),
45+
('--profiler False', False, PassThroughProfiler),
46+
],
47+
)
48+
def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, expected_profiler):
49+
cli_args = cli_args.split(' ')
50+
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
51+
parser = ArgumentParser(add_help=False)
52+
parser = Trainer.add_argparse_args(parent_parser=parser)
53+
args = Trainer.parse_argparser(parser)
54+
55+
assert getattr(args, "profiler") == expected_parsed_arg
56+
trainer = Trainer.from_argparse_args(args)
57+
assert isinstance(trainer.profiler, expected_profiler)
58+
59+
2560
def _soft_unimport_module(str_module):
2661
# once the module is imported e.g with parsing with pytest it lives in memory
2762
if str_module in sys.modules:

tests/trainer/test_trainer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
3333
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
3434
from pytorch_lightning.loggers import TensorBoardLogger
35+
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
3536
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
3637
from pytorch_lightning.utilities.cloud_io import load as pl_load
3738
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -1408,3 +1409,32 @@ def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, l
14081409
trainer.fit(model)
14091410
expected_calls = [call(metrics=ANY, step=s) for s in range(log_interval - 1, max_steps, log_interval)]
14101411
log_metrics_mock.assert_has_calls(expected_calls)
1412+
1413+
1414+
@pytest.mark.parametrize(['profiler', 'expected'], [
1415+
(None, PassThroughProfiler),
1416+
(SimpleProfiler(), SimpleProfiler),
1417+
(AdvancedProfiler(), AdvancedProfiler),
1418+
('simple', SimpleProfiler),
1419+
('Simple', SimpleProfiler),
1420+
('advanced', AdvancedProfiler),
1421+
])
1422+
def test_trainer_profiler_correct_args(profiler, expected):
1423+
kwargs = {'profiler': profiler} if profiler is not None else {}
1424+
trainer = Trainer(**kwargs)
1425+
assert isinstance(trainer.profiler, expected)
1426+
1427+
1428+
def test_trainer_profiler_incorrect_str_arg():
1429+
with pytest.raises(ValueError, match=r".*can only be 'simple' or 'advanced'"):
1430+
Trainer(profiler="unknown_profiler")
1431+
1432+
1433+
@pytest.mark.parametrize('profiler', (
1434+
42, [42], {"a": 42}, torch.tensor(42), Trainer(),
1435+
))
1436+
def test_trainer_profiler_incorrect_arg_type(profiler):
1437+
with pytest.raises(MisconfigurationException,
1438+
match=r"Only None, bool, str and subclasses of `BaseProfiler` "
1439+
r"are valid values for `Trainer`'s `profiler` parameter. *"):
1440+
Trainer(profiler=profiler)

0 commit comments

Comments (0)