Merged

Changes from all commits (29 commits)
7373bc7
allow trainer's profiler param to have a str value
ddrevicky Sep 25, 2020
a5212cd
add tests
ddrevicky Sep 25, 2020
5a9936f
update docs
ddrevicky Sep 25, 2020
5e9f895
update exception message
ddrevicky Sep 25, 2020
a91bab3
Update CHANGELOG
ddrevicky Sep 25, 2020
6c44099
fix pep8 issues
ddrevicky Sep 25, 2020
3661661
cleanup test code
ddrevicky Sep 25, 2020
7e82c3c
Add deprecation warning if using bool for profiler
ddrevicky Oct 1, 2020
01b1bc9
Add deprecation tests and move deprecated tests
ddrevicky Oct 1, 2020
f8bb828
Remove bool option to profiler from docs
ddrevicky Oct 1, 2020
264a08b
Deprecate bool args to profiler in CHANGELOG
ddrevicky Oct 1, 2020
e81ee90
fixup! Add deprecation warning if using bool for profiler
ddrevicky Oct 1, 2020
b0878ff
fixup! Add deprecation tests and move deprecated tests
ddrevicky Oct 1, 2020
b07e480
Apply suggestions from code review
ddrevicky Oct 2, 2020
bfc4650
Implement suggestions, remove whitespace
ddrevicky Oct 2, 2020
a660ba2
fixup! Implement suggestions, remove whitespace
ddrevicky Oct 2, 2020
7761151
Allow bool, str (case insensitive), BaseProfiler
ddrevicky Oct 2, 2020
f673bc2
Add info about bool deprecation to trainer
ddrevicky Oct 2, 2020
f2780e7
fixup! Add info about bool deprecation to trainer
ddrevicky Oct 2, 2020
443d626
Move deprecate todo to test_deprecated
ddrevicky Oct 2, 2020
b561bf6
Test wrong profiler type, improve error message
ddrevicky Oct 2, 2020
2042512
fixup! Test wrong profiler type, improve error message
ddrevicky Oct 2, 2020
3035dda
Update pytorch_lightning/trainer/connectors/profiler_connector.py
ddrevicky Oct 2, 2020
348a129
Apply suggestions from code review
Borda Oct 2, 2020
29f7833
Readd bool to profiler types, test cli profiler arg
ddrevicky Oct 3, 2020
b5cecd6
Remove extra whitespace in doc
ddrevicky Oct 24, 2020
4f92fba
Apply suggestions from code review
Borda Oct 24, 2020
f911f37
Update deprecation versions
ddrevicky Oct 25, 2020
c45d465
Merge branch 'master' into feature/3330_trainer_profiler_str
rohitgr7 Oct 27, 2020
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344))


+- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))


### Changed


@@ -48,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))


+- Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))


### Removed


10 changes: 7 additions & 3 deletions pytorch_lightning/profiler/__init__.py
@@ -22,12 +22,12 @@
Enable simple profiling
-----------------------

-If you only wish to profile the standard actions, you can set `profiler=True` when constructing
-your `Trainer` object.
+If you only wish to profile the standard actions, you can set `profiler="simple"`
+when constructing your `Trainer` object.

 .. code-block:: python

-    trainer = Trainer(..., profiler=True)
+    trainer = Trainer(..., profiler="simple")

The profiler's results will be printed at the completion of a training `fit()`.

@@ -59,6 +59,10 @@

 .. code-block:: python

+    trainer = Trainer(..., profiler="advanced")
+
+or
+
     profiler = AdvancedProfiler()
     trainer = Trainer(..., profiler=profiler)

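Note: taken together, the docs changes above make `profiler="simple"` and `profiler="advanced"` shorthands
for passing the corresponding profiler instance. A minimal sketch of that equivalence (illustration only,
not part of the diff; assumes a default-constructed Trainer):

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler

    # the string form selects the same profiler class as the instance form
    assert isinstance(Trainer(profiler="simple").profiler, SimpleProfiler)
    assert isinstance(Trainer(profiler=SimpleProfiler()).profiler, SimpleProfiler)

    # matching is case insensitive ("Advanced" and "advanced" behave the same)
    assert isinstance(Trainer(profiler="Advanced").profiler, AdvancedProfiler)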
11 changes: 4 additions & 7 deletions pytorch_lightning/trainer/__init__.py
@@ -1199,14 +1199,11 @@ def world_size(self):
     # default used by the Trainer
     trainer = Trainer(profiler=None)

-    # to profile standard training events
-    trainer = Trainer(profiler=True)
+    # to profile standard training events, equivalent to `profiler=SimpleProfiler()`
+    trainer = Trainer(profiler="simple")

-    # equivalent to profiler=True
-    trainer = Trainer(profiler=SimpleProfiler())
-
-    # advanced profiler for function-level stats
-    trainer = Trainer(profiler=AdvancedProfiler())
+    # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
+    trainer = Trainer(profiler="advanced")

progress_bar_refresh_rate
^^^^^^^^^^^^^^^^^^^^^^^^^
34 changes: 29 additions & 5 deletions pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -11,16 +11,40 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
-from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler
+
+from typing import Union
+
+from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler, AdvancedProfiler
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class ProfilerConnector:

     def __init__(self, trainer):
         self.trainer = trainer

-    def on_trainer_init(self, profiler):
-        # configure profiler
-        if profiler is True:
-            profiler = SimpleProfiler()
+    def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
+
+        if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
+            # TODO: Update exception on removal of bool
+            raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` "
+                                            "are valid values for `Trainer`'s `profiler` parameter. "
+                                            f"Received {profiler} which is of type {type(profiler)}.")
+
+        if isinstance(profiler, bool):
+            rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
+                           " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.",
+                           DeprecationWarning)
+            if profiler:
+                profiler = SimpleProfiler()
+        elif isinstance(profiler, str):
+            profiler = profiler.lower()
+            if profiler == "simple":
+                profiler = SimpleProfiler()
+            elif profiler == "advanced":
+                profiler = AdvancedProfiler()
+            else:
+                raise ValueError("When passing string value for the `profiler` parameter of"
+                                 " `Trainer`, it can only be 'simple' or 'advanced'")
         self.trainer.profiler = profiler or PassThroughProfiler()
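As a usage note (illustration only, not part of the diff), the connector logic above gives the
following behavior at `Trainer` construction time:

    import pytest

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler

    # bool still works but now warns; True maps to SimpleProfiler
    with pytest.deprecated_call(match="will be removed in v1.3"):
        assert isinstance(Trainer(profiler=True).profiler, SimpleProfiler)

    # unknown strings are rejected outright
    with pytest.raises(ValueError, match="can only be 'simple' or 'advanced'"):
        Trainer(profiler="unknown_profiler")

    # None (and falsy values) fall through to the no-op profiler
    assert isinstance(Trainer(profiler=None).profiler, PassThroughProfiler)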
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -120,7 +120,7 @@ def __init__(
         num_sanity_val_steps: int = 2,
         truncated_bptt_steps: Optional[int] = None,
         resume_from_checkpoint: Optional[str] = None,
-        profiler: Optional[Union[BaseProfiler, bool]] = None,
+        profiler: Optional[Union[BaseProfiler, bool, str]] = None,
         benchmark: bool = False,
         deterministic: bool = False,
         reload_dataloaders_every_epoch: bool = False,
@@ -212,7 +212,8 @@ def __init__(
             progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
                 Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

-            profiler: To profile individual steps during training and assist in identifying bottlenecks.
+            profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
+                value is deprecated in v1.1 and will be removed in v1.3.

             overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

3 changes: 1 addition & 2 deletions pytorch_lightning/utilities/argparse_utils.py
@@ -174,8 +174,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
             # if the only arg type is bool
             if len(arg_types) == 1:
                 use_type = parsing.str_to_bool
-            # if only two args (str, bool)
-            elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
+            elif str in arg_types:
                 use_type = parsing.str_to_bool_or_str
             else:
                 # filter out the bool as we need to use more general
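For context, `parsing.str_to_bool_or_str` is what lets a CLI token such as `--profiler True` keep
working while `--profiler simple` parses as a plain string. A sketch of the behavior this branch
relies on (an assumption based on the helper's name and the tests below, not the library's exact
source):

    from typing import Union

    def str_to_bool_or_str(val: str) -> Union[bool, str]:
        # interpret common truthy/falsy tokens as bools...
        lower = val.lower()
        if lower in ('y', 'yes', 't', 'true', 'on', '1'):
            return True
        if lower in ('n', 'no', 'f', 'false', 'off', '0'):
            return False
        # ...and fall back to the raw string (e.g. "simple", "advanced")
        return val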
35 changes: 35 additions & 0 deletions tests/test_deprecated.py
@@ -1,13 +1,17 @@
"""Test deprecated functionality which will be removed in vX.Y.Z"""
from argparse import ArgumentParser
import pytest
import sys
from unittest import mock

import torch

from tests.base import EvalModelTemplate
from pytorch_lightning.metrics.functional.classification import auc

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -22,6 +26,37 @@ def test_tbd_remove_in_v1_2_0():
     checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.')


+# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
+@pytest.mark.parametrize(['profiler', 'expected'], [
+    (True, SimpleProfiler),
+    (False, PassThroughProfiler),
+])
+def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
+    with pytest.deprecated_call(match='will be removed in v1.3'):
+        trainer = Trainer(profiler=profiler)
+        assert isinstance(trainer.profiler, expected)
+
+
+@pytest.mark.parametrize(
+    ['cli_args', 'expected_parsed_arg', 'expected_profiler'],
+    [
+        ('--profiler', True, SimpleProfiler),
+        ('--profiler True', True, SimpleProfiler),
+        ('--profiler False', False, PassThroughProfiler),
+    ],
+)
+def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, expected_profiler):
+    cli_args = cli_args.split(' ')
+    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
+        parser = ArgumentParser(add_help=False)
+        parser = Trainer.add_argparse_args(parent_parser=parser)
+        args = Trainer.parse_argparser(parser)
+
+    assert getattr(args, "profiler") == expected_parsed_arg
+    trainer = Trainer.from_argparse_args(args)
+    assert isinstance(trainer.profiler, expected_profiler)
+
+
 def _soft_unimport_module(str_module):
     # once the module is imported e.g with parsing with pytest it lives in memory
     if str_module in sys.modules:
30 changes: 30 additions & 0 deletions tests/trainer/test_trainer.py
@@ -32,6 +32,7 @@
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -1408,3 +1409,32 @@ def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, l
     trainer.fit(model)
     expected_calls = [call(metrics=ANY, step=s) for s in range(log_interval - 1, max_steps, log_interval)]
     log_metrics_mock.assert_has_calls(expected_calls)
+
+
+@pytest.mark.parametrize(['profiler', 'expected'], [
+    (None, PassThroughProfiler),
+    (SimpleProfiler(), SimpleProfiler),
+    (AdvancedProfiler(), AdvancedProfiler),
+    ('simple', SimpleProfiler),
+    ('Simple', SimpleProfiler),
+    ('advanced', AdvancedProfiler),
+])
+def test_trainer_profiler_correct_args(profiler, expected):
+    kwargs = {'profiler': profiler} if profiler is not None else {}
+    trainer = Trainer(**kwargs)
+    assert isinstance(trainer.profiler, expected)
+
+
+def test_trainer_profiler_incorrect_str_arg():
+    with pytest.raises(ValueError, match=r".*can only be 'simple' or 'advanced'"):
+        Trainer(profiler="unknown_profiler")
+
+
+@pytest.mark.parametrize('profiler', (
+    42, [42], {"a": 42}, torch.tensor(42), Trainer(),
+))
+def test_trainer_profiler_incorrect_arg_type(profiler):
+    with pytest.raises(MisconfigurationException,
+                       match=r"Only None, bool, str and subclasses of `BaseProfiler` "
+                             r"are valid values for `Trainer`'s `profiler` parameter. *"):
+        Trainer(profiler=profiler)