From 30d76b022351d91c883b2185b683799147696e9c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Mar 2021 15:34:27 +0000 Subject: [PATCH 1/6] add setup --- .../plugins/training_type/horovod.py | 6 +-- .../training_type/training_type_plugin.py | 6 +-- pytorch_lightning/profiler/profilers.py | 44 ++++++------------- pytorch_lightning/profiler/pytorch.py | 21 ++------- .../trainer/connectors/profiler_connector.py | 6 ++- pytorch_lightning/trainer/trainer.py | 22 +++++++--- pytorch_lightning/trainer/training_loop.py | 3 -- tests/test_profiler.py | 17 ++++--- 8 files changed, 53 insertions(+), 72 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 9f1bafe309f89..8d0add27cbb29 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -96,14 +96,14 @@ def start_training(self, trainer): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self._results = trainer.run_train() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.run_evaluate() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() @@ -111,7 +111,7 @@ def start_evaluating(self, trainer): def start_predicting(self, trainer): with ExitStack(): # set up training routine - self._results = trainer.run_predict() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index b6f1be359bbf2..89f27963caadf 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -132,15 +132,15 @@ def rpc_enabled(self) -> bool: def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop - self._results = trainer.run_train() + self._results = trainer.run_stage() def start_evaluating(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop - self._results = trainer.run_evaluate() + self._results = trainer.run_stage() def start_predicting(self, trainer: 'Trainer') -> None: # double dispatch to initiate the predicting loop - self._results = trainer.run_predict() + self._results = trainer.run_stage() def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 55898dc2ee4e1..8b7352e89fe14 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -55,9 +55,17 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" - def teardown(self) -> None: + def setup(self, stage: str, local_rank: Optional[int], log_dir: Optional[str]): """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" - pass + self.stage = stage + self.local_rank = local_rank + self.log_dir = log_dir + + def teardown(self, stage: Optional[str] = None) -> None: + """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + if 
self.output_file: + self.output_file.close() + self.output_file = None @contextmanager def profile(self, action_name: str) -> None: @@ -94,13 +102,15 @@ def describe(self) -> None: """Logs a profile report after the conclusion of the training run.""" for write in self.write_streams: write(self.summary()) + if self.output_file is not None: + self.output_file.flush() @abstractmethod def summary(self) -> str: """Create profiler summary in text format.""" - def on_train_start(self, local_rank: Optional[int] = None): - self.local_rank = local_rank + def __del__(self): + self.teardown(None) class PassThroughProfiler(BaseProfiler): @@ -212,19 +222,6 @@ def log_row(action, mean, total): output_string += os.linesep return output_string - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - self.teardown() - - def teardown(self) -> None: - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() - - def __del__(self): - self.teardown() - class AdvancedProfiler(BaseProfiler): """ @@ -285,16 +282,3 @@ def summary(self) -> str: output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - self.teardown() - - def teardown(self) -> None: - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() - - def __del__(self): - self.teardown() diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index fdde80589acf3..0b2b20cdc7814 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -162,11 +162,11 @@ def __init__( self.output_fname = output_filename self.output_file = None if local_rank is not None: - self.on_train_start(local_rank=local_rank) - self.on_train_start = super().on_train_start + self.setup(None, local_rank, None) + self.setup = super().setup - def on_train_start(self, local_rank: Optional[str] = None): - self.local_rank = local_rank + def setup(self, stage: str, local_rank: Optional[int], log_dir: Optional[str]): + super().setup(stage, local_rank, log_dir) # when logging to `log.info`, only perform profiling on rank 0 if local_rank != 0 and self.output_fname is None: @@ -290,16 +290,3 @@ def summary(self) -> str: output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}") return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - self.teardown() - - def teardown(self) -> None: - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() - - def __del__(self): - self.teardown() diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 98d65c1285ff7..45d51aeaaa952 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -54,6 +54,8 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]): ) self.trainer.profiler = profiler or PassThroughProfiler() - def on_train_start(self, trainer): + def on_run_stage_setup(self): + trainer = self.trainer local_rank = trainer.local_rank if trainer.world_size > 1 else None - self.trainer.profiler.on_train_start(local_rank) + trainer.profiler.lightning_module = trainer.lightning_module + 
trainer.profiler.setup(trainer.state, local_rank, trainer.log_dir) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a5b99871d55f9..40b59cf872520 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -445,13 +445,15 @@ def fit( | || {self.dispatch} || | || LIGHTNING - {self.accelerator.start_training} or || - {self.accelerator.start_evaluating} or || FLOW - {self.accelerator.start_predicting} || + {self.accelerator.start_training} || + or {self.accelerator.start_evaluating} || + or {self.accelerator.start_predicting} || FLOW + | || + {self.run_stage} || | || DIRECTION - {self.run_train} or || - {self.run_evaluation} or || - {self.run_predict} || + {self.run_train} || + or {self.run_evaluation} || + or {self.run_predict} || | || results \/ This is used to guide readers to the core loops: train, test, predict. @@ -518,6 +520,9 @@ def dispatch(self): def run_stage(self): results = None + + self._run_stage_setup() + if self.evaluating: results = self.run_evaluate() elif self.predicting: @@ -526,6 +531,9 @@ def run_stage(self): self.run_train() return results + def _run_stage_setup(self): + self.profile_connector.on_run_stage_setup() + def _pre_training_routine(self): # wait for all to join if on distributed self.accelerator.barrier("setup_training") @@ -1077,7 +1085,7 @@ def call_teardown_hook(self, model: LightningModule) -> None: else: state = None - self.profiler.teardown() + self.profiler.teardown(stage=state.value) self.teardown(stage=state) model.teardown(stage=state) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a77d91a7402b4..384a1b67a64f8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -102,9 +102,6 @@ def on_train_start(self): # hook self.trainer.call_hook("on_train_start") - # provide rank to profiler - self.trainer.profile_connector.on_train_start(self.trainer) - def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): diff --git a/tests/test_profiler.py b/tests/test_profiler.py index ccdd8a569c9a8..cc4fff3b7ede4 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -22,7 +22,8 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler, PyTorchProfiler +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -323,14 +324,16 @@ def test_profiler_teardown(tmpdir, cls): """ This test checks if profiler teardown method is called when trainer is exiting. 
""" + + class TestCallback(Callback): + + def on_fit_end(self, trainer, pl_module) -> None: + assert trainer.profiler.output_file is not None + profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt")) model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - profiler=profiler, - ) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()]) trainer.fit(model) - assert profiler.output_file.closed + assert profiler.output_file is None From 32d754619273bbd1b93f1ef2746247684e6ac3f3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Mar 2021 15:54:09 +0000 Subject: [PATCH 2/6] update --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 40b59cf872520..85b5c12a26efc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1085,7 +1085,7 @@ def call_teardown_hook(self, model: LightningModule) -> None: else: state = None - self.profiler.teardown(stage=state.value) + self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) From 0f3cef917cf851a06fe7d24766a982a8f70bbb28 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Mar 2021 16:29:34 +0000 Subject: [PATCH 3/6] updates on comment --- pytorch_lightning/profiler/profilers.py | 1 + .../trainer/connectors/profiler_connector.py | 5 ++--- pytorch_lightning/trainer/properties.py | 11 +++++++++++ pytorch_lightning/trainer/trainer.py | 14 +++----------- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 8b7352e89fe14..7e8aba3d73c15 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -120,6 +120,7 @@ class PassThroughProfiler(BaseProfiler): """ def __init__(self): + self.output_file = None super().__init__(output_streams=None) def start(self, action_name: str) -> None: diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 45d51aeaaa952..2088099cf2a8a 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -54,8 +54,7 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]): ) self.trainer.profiler = profiler or PassThroughProfiler() - def on_run_stage_setup(self): + def setup(self): trainer = self.trainer local_rank = trainer.local_rank if trainer.world_size > 1 else None - trainer.profiler.lightning_module = trainer.lightning_module - trainer.profiler.setup(trainer.state, local_rank, trainer.log_dir) + trainer.profiler.setup(trainer.setup_state, local_rank, trainer.log_dir) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index b5654b148afc6..1594eeaadc8b1 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -491,6 +491,17 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self._running_stage = None + @property + def setup_state(self): + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" + return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + + @property + def teardown_state(self): + if self.state.running: + return self.setup_state + return None + # Used to 
represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 85b5c12a26efc..011287024bcaf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -521,7 +521,7 @@ def dispatch(self): def run_stage(self): results = None - self._run_stage_setup() + self.profile_connector.setup() if self.evaluating: results = self.run_evaluate() @@ -531,9 +531,6 @@ def run_stage(self): self.run_train() return results - def _run_stage_setup(self): - self.profile_connector.on_run_stage_setup() - def _pre_training_routine(self): # wait for all to join if on distributed self.accelerator.barrier("setup_training") @@ -1068,8 +1065,7 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" - # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = self.setup_state if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1080,11 +1076,7 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - if self.state.running: - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - else: - state = None - + state = self.teardown_state self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) From a01878db3343d64cb9c5391ba875a8780bf01886 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Mar 2021 17:40:24 +0100 Subject: [PATCH 4/6] Minor changes --- pytorch_lightning/profiler/profilers.py | 8 +++++++- pytorch_lightning/profiler/pytorch.py | 6 +++--- .../trainer/connectors/profiler_connector.py | 6 ++++-- pytorch_lightning/trainer/properties.py | 7 +++---- pytorch_lightning/trainer/trainer.py | 8 ++++---- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 7e8aba3d73c15..cb41fd04b2b7d 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -55,7 +55,12 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" - def setup(self, stage: str, local_rank: Optional[int], log_dir: Optional[str]): + def setup( + self, + stage: Optional[str] = None, + local_rank: Optional[int] = None, + log_dir: Optional[str] = None + ) -> None: """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" self.stage = stage self.local_rank = local_rank @@ -63,6 +68,7 @@ def setup(self, stage: str, local_rank: Optional[int], log_dir: Optional[str]): def teardown(self, stage: Optional[str] = None) -> None: """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + self.stage = stage if self.output_file: self.output_file.close() self.output_file = None diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 0b2b20cdc7814..c35979fa918af 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -162,11 +162,11 @@ def __init__( self.output_fname = output_filename self.output_file = None if local_rank is not None: - self.setup(None, 
local_rank, None) + self.setup(local_rank=local_rank) self.setup = super().setup - def setup(self, stage: str, local_rank: Optional[int], log_dir: Optional[str]): - super().setup(stage, local_rank, log_dir) + def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None): + super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir) # when logging to `log.info`, only perform profiling on rank 0 if local_rank != 0 and self.output_fname is None: diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 2088099cf2a8a..a3af10877325b 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -14,6 +14,8 @@ from typing import Union +from transformers import TrainerState + from pytorch_lightning.profiler import ( AdvancedProfiler, BaseProfiler, @@ -54,7 +56,7 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]): ) self.trainer.profiler = profiler or PassThroughProfiler() - def setup(self): + def setup(self) -> None: trainer = self.trainer local_rank = trainer.local_rank if trainer.world_size > 1 else None - trainer.profiler.setup(trainer.setup_state, local_rank, trainer.log_dir) + trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 1594eeaadc8b1..315e3c60c0557 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -492,15 +492,14 @@ def sanity_checking(self, val: bool) -> None: self._running_stage = None @property - def setup_state(self): + def _setup_state(self) -> TrainerState: # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state @property - def teardown_state(self): + def _teardown_state(self) -> Optional[TrainerState]: if self.state.running: - return self.setup_state - return None + return self._setup_state # Used to represent the concrete type TrainerProperties class methods are called on. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 011287024bcaf..f7bd1757b9bc2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -445,7 +445,7 @@ def fit( | || {self.dispatch} || | || LIGHTNING - {self.accelerator.start_training} || + {self.accelerator.start_training} || or {self.accelerator.start_evaluating} || or {self.accelerator.start_predicting} || FLOW | || @@ -453,7 +453,7 @@ def fit( | || DIRECTION {self.run_train} || or {self.run_evaluation} || - or {self.run_predict} || + or {self.run_predict} || | || results \/ This is used to guide readers to the core loops: train, test, predict. 
@@ -1065,7 +1065,7 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" - state = self.setup_state + state = self._setup_state if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1076,7 +1076,7 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - state = self.teardown_state + state = self._teardown_state self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) From 4d84040ba8d9854ed7ac1144a519465d2469fccf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Mar 2021 17:41:33 +0100 Subject: [PATCH 5/6] Extra import --- pytorch_lightning/trainer/connectors/profiler_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index a3af10877325b..e628d6d96bd19 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -14,8 +14,6 @@ from typing import Union -from transformers import TrainerState - from pytorch_lightning.profiler import ( AdvancedProfiler, BaseProfiler, From 4d35c58240e3e2132474e09eb960562c1c380272 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Mar 2021 17:46:18 +0100 Subject: [PATCH 6/6] Docs --- CHANGELOG.md | 3 +++ pytorch_lightning/profiler/profilers.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f005f583c5ed..51ad97decd867 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370)) +- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633)) + + - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index cb41fd04b2b7d..5668fd6654b2f 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -61,13 +61,13 @@ def setup( local_rank: Optional[int] = None, log_dir: Optional[str] = None ) -> None: - """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + """Execute arbitrary pre-profiling set-up steps.""" self.stage = stage self.local_rank = local_rank self.log_dir = log_dir def teardown(self, stage: Optional[str] = None) -> None: - """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + """Execute arbitrary post-profiling tear-down steps.""" self.stage = stage if self.output_file: self.output_file.close()
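
For reference, the sketch below shows how a third-party profiler could hook into the
setup/teardown API introduced by this series: the trainer's ProfilerConnector now calls
profiler.setup(stage=..., local_rank=..., log_dir=...) at the start of run_stage, and
call_teardown_hook calls profiler.teardown(stage=...). The class name, file naming and
timing logic below are illustrative assumptions only, not part of the patches above.

    import os
    import time
    from typing import Optional

    from pytorch_lightning.profiler import BaseProfiler


    class StageAwareProfiler(BaseProfiler):
        """Times each action and writes one report per stage and rank (sketch only)."""

        def __init__(self):
            # teardown()/__del__ in the new BaseProfiler expect `output_file` to exist.
            self.output_file = None
            self._start_times = {}
            self._totals = {}
            super().__init__(output_streams=None)

        def setup(
            self,
            stage: Optional[str] = None,
            local_rank: Optional[int] = None,
            log_dir: Optional[str] = None,
        ) -> None:
            # Store stage/local_rank/log_dir on the instance (base behaviour) and open
            # a per-stage, per-rank report file inside the trainer's log_dir.
            super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir)
            if log_dir is not None:
                fname = f"profile_{stage}_rank{local_rank or 0}.txt"
                self.output_file = open(os.path.join(log_dir, fname), "w")

        def start(self, action_name: str) -> None:
            self._start_times[action_name] = time.monotonic()

        def stop(self, action_name: str) -> None:
            elapsed = time.monotonic() - self._start_times.pop(action_name)
            self._totals[action_name] = self._totals.get(action_name, 0.0) + elapsed

        def summary(self) -> str:
            return os.linesep.join(
                f"{name}: {total:.3f}s" for name, total in sorted(self._totals.items())
            )

        def teardown(self, stage: Optional[str] = None) -> None:
            # Write the report before the base class closes and clears `output_file`,
            # which is what the updated test_profiler_teardown asserts.
            if self.output_file is not None:
                self.output_file.write(self.summary())
            super().teardown(stage=stage)


    # Usage (sketch): pass the instance like any other profiler.
    # trainer = Trainer(profiler=StageAwareProfiler(), fast_dev_run=True)
    # trainer.fit(model)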