diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5f005f583c5ed..51ad97decd867 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370))
 
 
+- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633))
+
+
 - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
 
 
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 9f1bafe309f89..8d0add27cbb29 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -96,14 +96,14 @@ def start_training(self, trainer):
                 stack.enter_context(optimizer.skip_synchronize())
 
             # set up training routine
-            self._results = trainer.run_train()
+            self._results = trainer.run_stage()
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
 
     def start_evaluating(self, trainer):
         with ExitStack():
-            self._results = trainer.run_evaluate()
+            self._results = trainer.run_stage()
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
@@ -111,7 +111,7 @@ def start_evaluating(self, trainer):
     def start_predicting(self, trainer):
         with ExitStack():
             # set up training routine
-            self._results = trainer.run_predict()
+            self._results = trainer.run_stage()
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index b6f1be359bbf2..89f27963caadf 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -132,15 +132,15 @@ def rpc_enabled(self) -> bool:
 
     def start_training(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the training loop
-        self._results = trainer.run_train()
+        self._results = trainer.run_stage()
 
     def start_evaluating(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the test loop
-        self._results = trainer.run_evaluate()
+        self._results = trainer.run_stage()
 
     def start_predicting(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the predicting loop
-        self._results = trainer.run_predict()
+        self._results = trainer.run_stage()
 
     def training_step(self, *args, **kwargs):
         return self.lightning_module.training_step(*args, **kwargs)
diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py
index 55898dc2ee4e1..5668fd6654b2f 100644
--- a/pytorch_lightning/profiler/profilers.py
+++ b/pytorch_lightning/profiler/profilers.py
@@ -55,9 +55,23 @@ def start(self, action_name: str) -> None:
     def stop(self, action_name: str) -> None:
         """Defines how to record the duration once an action is complete."""
 
-    def teardown(self) -> None:
-        """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
-        pass
+    def setup(
+        self,
+        stage: Optional[str] = None,
+        local_rank: Optional[int] = None,
+        log_dir: Optional[str] = None
+    ) -> None:
+        """Execute arbitrary pre-profiling set-up steps."""
+        self.stage = stage
+        self.local_rank = local_rank
+        self.log_dir = log_dir
+
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """Execute arbitrary post-profiling tear-down steps."""
+        self.stage = stage
+        if self.output_file:
+            self.output_file.close()
+            self.output_file = None
 
     @contextmanager
     def profile(self, action_name: str) -> None:
@@ -94,13 +108,15 @@ def describe(self) -> None:
         """Logs a profile report after the conclusion of the training run."""
         for write in self.write_streams:
             write(self.summary())
+        if self.output_file is not None:
+            self.output_file.flush()
 
     @abstractmethod
     def summary(self) -> str:
         """Create profiler summary in text format."""
 
-    def on_train_start(self, local_rank: Optional[int] = None):
-        self.local_rank = local_rank
+    def __del__(self):
+        self.teardown(None)
 
 
 class PassThroughProfiler(BaseProfiler):
@@ -110,6 +126,7 @@ class PassThroughProfiler(BaseProfiler):
     """
 
     def __init__(self):
+        self.output_file = None
         super().__init__(output_streams=None)
 
     def start(self, action_name: str) -> None:
@@ -212,19 +229,6 @@ def log_row(action, mean, total):
             output_string += os.linesep
         return output_string
 
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()
-
 
 class AdvancedProfiler(BaseProfiler):
     """
@@ -285,16 +289,3 @@ def summary(self) -> str:
             output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
 
         return output_string
-
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()
diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py
index fdde80589acf3..c35979fa918af 100644
--- a/pytorch_lightning/profiler/pytorch.py
+++ b/pytorch_lightning/profiler/pytorch.py
@@ -162,11 +162,11 @@ def __init__(
         self.output_fname = output_filename
         self.output_file = None
         if local_rank is not None:
-            self.on_train_start(local_rank=local_rank)
-            self.on_train_start = super().on_train_start
+            self.setup(local_rank=local_rank)
+            self.setup = super().setup
 
-    def on_train_start(self, local_rank: Optional[str] = None):
-        self.local_rank = local_rank
+    def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None):
+        super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir)
 
         # when logging to `log.info`, only perform profiling on rank 0
         if local_rank != 0 and self.output_fname is None:
@@ -290,16 +290,3 @@ def summary(self) -> str:
             output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}")
 
         return output_string
-
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()
diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py
index 98d65c1285ff7..e628d6d96bd19 100644
--- a/pytorch_lightning/trainer/connectors/profiler_connector.py
+++ b/pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -54,6 +54,7 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
             )
         self.trainer.profiler = profiler or PassThroughProfiler()
 
-    def on_train_start(self, trainer):
+    def setup(self) -> None:
+        trainer = self.trainer
         local_rank = trainer.local_rank if trainer.world_size > 1 else None
-        self.trainer.profiler.on_train_start(local_rank)
+        trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index b5654b148afc6..315e3c60c0557 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -491,6 +491,16 @@ def sanity_checking(self, val: bool) -> None:
         elif self.sanity_checking:
             self._running_stage = None
 
+    @property
+    def _setup_state(self) -> TrainerState:
+        # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
+        return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
+
+    @property
+    def _teardown_state(self) -> Optional[TrainerState]:
+        if self.state.running:
+            return self._setup_state
+
 
 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a5b99871d55f9..f7bd1757b9bc2 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -445,13 +445,15 @@ def fit(
                         |                             ||
                 {self.dispatch}                       ||
                         |                             ||  LIGHTNING
-        {self.accelerator.start_training} or          ||
-        {self.accelerator.start_evaluating} or        ||  FLOW
-        {self.accelerator.start_predicting}           ||
+        {self.accelerator.start_training}             ||
+        or {self.accelerator.start_evaluating}        ||
+        or {self.accelerator.start_predicting}        ||  FLOW
+                        |                             ||
+                {self.run_stage}                      ||
                         |                             ||  DIRECTION
-            {self.run_train} or                       ||
-            {self.run_evaluation} or                  ||
-            {self.run_predict}                        ||
+                {self.run_train}                      ||
+                or {self.run_evaluation}              ||
+                or {self.run_predict}                 ||
                         |                             ||
                      results                          \/
         This is used to guide readers to the core loops: train, test, predict.
@@ -518,6 +520,9 @@ def dispatch(self):
 
     def run_stage(self):
         results = None
+
+        self.profile_connector.setup()
+
         if self.evaluating:
             results = self.run_evaluate()
         elif self.predicting:
@@ -1060,8 +1065,7 @@ def tune(
     def call_setup_hook(self, model: LightningModule) -> None:
         assert self.state.running, f"TrainerState: {self.state}"
 
-        # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
-        state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
+        state = self._setup_state
 
         if self.datamodule is not None:
             called = getattr(self.datamodule, f'has_setup_{state}')
@@ -1072,12 +1076,8 @@ def call_setup_hook(self, model: LightningModule) -> None:
         model.setup(stage=state)
 
     def call_teardown_hook(self, model: LightningModule) -> None:
-        if self.state.running:
-            state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
-        else:
-            state = None
-
-        self.profiler.teardown()
+        state = self._teardown_state
+        self.profiler.teardown(stage=state)
         self.teardown(stage=state)
         model.teardown(stage=state)
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index a77d91a7402b4..384a1b67a64f8 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -102,9 +102,6 @@ def on_train_start(self):
         # hook
         self.trainer.call_hook("on_train_start")
 
-        # provide rank to profiler
-        self.trainer.profile_connector.on_train_start(self.trainer)
-
     def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
         # clean hparams
         if hasattr(model, "hparams"):
diff --git a/tests/test_profiler.py b/tests/test_profiler.py
index ccdd8a569c9a8..cc4fff3b7ede4 100644
--- a/tests/test_profiler.py
+++ b/tests/test_profiler.py
@@ -22,7 +22,8 @@
 import torch
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler, PyTorchProfiler
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -323,14 +324,16 @@ def test_profiler_teardown(tmpdir, cls):
     """
     This test checks if profiler teardown method is called when trainer is exiting.
     """
+
+    class TestCallback(Callback):
+
+        def on_fit_end(self, trainer, pl_module) -> None:
+            assert trainer.profiler.output_file is not None
+
     profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt"))
     model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        profiler=profiler,
-    )
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
     trainer.fit(model)
 
-    assert profiler.output_file.closed
+    assert profiler.output_file is None
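
A minimal sketch of how a custom profiler might use the hooks introduced in this diff. It assumes only what the diff shows: the `BaseProfiler.setup(stage, local_rank, log_dir)` / `teardown(stage)` signatures and the `SimpleProfiler(output_filename=...)` constructor; the subclass name and the `print` calls are purely illustrative, not part of the change.

```python
from typing import Optional

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import SimpleProfiler


class RankAwareProfiler(SimpleProfiler):
    """Illustrative subclass hooking into the new per-process set-up/tear-down."""

    def setup(
        self,
        stage: Optional[str] = None,
        local_rank: Optional[int] = None,
        log_dir: Optional[str] = None,
    ) -> None:
        # runs once per process at the start of `run_stage`, before any action is profiled;
        # the base class stores `stage`, `local_rank` and `log_dir` on the profiler
        super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir)
        print(f"profiling stage={stage} on rank={local_rank}, log_dir={log_dir}")

    def teardown(self, stage: Optional[str] = None) -> None:
        # the base class now closes `self.output_file` here, so subclasses only add
        # their own clean-up before delegating
        print(f"finished profiling stage={stage}")
        super().teardown(stage=stage)


# the trainer drives both hooks itself: `setup` via the profiler connector in `run_stage`,
# `teardown` via `call_teardown_hook` (with `__del__` as a last-resort fallback)
trainer = Trainer(profiler=RankAwareProfiler(output_filename="profiler.txt"), fast_dev_run=True)
```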