Commit 2064ece

tchaton and carmocca authored

[refactor] Add setup to profilers + _run_stage_setup to trainer 2/5 (#6633)

* add setup
* update
* updates on comment
* Minor changes
* Extra import
* Docs

Co-authored-by: Carlos Mocholi <[email protected]>

1 parent e62c7c7 · commit 2064ece

File tree: 10 files changed, +72 −80 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370))
 
 
+- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633))
+
+
 - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 3 additions & 3 deletions

@@ -96,22 +96,22 @@ def start_training(self, trainer):
                 stack.enter_context(optimizer.skip_synchronize())
 
             # set up training routine
-            self._results = trainer.run_train()
+            self._results = trainer.run_stage()
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
 
     def start_evaluating(self, trainer):
         with ExitStack():
-            self._results = trainer.run_evaluate()
+            self._results = trainer.run_stage()
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
 
     def start_predicting(self, trainer):
         with ExitStack():
             # set up training routine
-            self._results = trainer.run_predict()
+            self._results = trainer.run_stage()
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
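All three Horovod entry points now funnel into the same `trainer.run_stage()` call; the plugin keeps only its transport-specific concerns — skipping optimizer synchronization inside the `ExitStack` during training, and the `hvd.join()` barrier that holds every worker until all ranks have finished before control returns to the user.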

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 3 additions & 3 deletions

@@ -132,15 +132,15 @@ def rpc_enabled(self) -> bool:
 
     def start_training(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the training loop
-        self._results = trainer.run_train()
+        self._results = trainer.run_stage()
 
     def start_evaluating(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the test loop
-        self._results = trainer.run_evaluate()
+        self._results = trainer.run_stage()
 
     def start_predicting(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the predicting loop
-        self._results = trainer.run_predict()
+        self._results = trainer.run_stage()
 
     def training_step(self, *args, **kwargs):
         return self.lightning_module.training_step(*args, **kwargs)
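All three plugin entry points now invoke the same `trainer.run_stage()`; the trainer, which knows its own running stage, picks the loop. A minimal sketch of this double-dispatch shape (hypothetical stand-in classes, not the actual Lightning code):

```python
class Trainer:
    """Stand-in trainer: owns the stage and picks the loop to run."""

    def __init__(self, stage: str):
        self.stage = stage  # e.g. "fit", "test", "predict"

    def run_stage(self):
        # Second dispatch: the trainer routes to the right loop.
        if self.stage == "predict":
            return "predict results"
        if self.stage == "test":
            return "eval results"
        return "train results"


class TrainingTypePlugin:
    """Stand-in plugin: wraps the call but no longer picks the loop."""

    def start_training(self, trainer: "Trainer") -> None:
        # First dispatch: accelerator -> plugin, so distributed plugins can
        # add their own context (e.g. Horovod's ExitStack and hvd.join()).
        self._results = trainer.run_stage()


plugin = TrainingTypePlugin()
plugin.start_training(Trainer("fit"))
assert plugin._results == "train results"
```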

pytorch_lightning/profiler/profilers.py

Lines changed: 22 additions & 31 deletions

@@ -55,9 +55,23 @@ def start(self, action_name: str) -> None:
     def stop(self, action_name: str) -> None:
         """Defines how to record the duration once an action is complete."""
 
-    def teardown(self) -> None:
-        """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
-        pass
+    def setup(
+        self,
+        stage: Optional[str] = None,
+        local_rank: Optional[int] = None,
+        log_dir: Optional[str] = None
+    ) -> None:
+        """Execute arbitrary pre-profiling set-up steps."""
+        self.stage = stage
+        self.local_rank = local_rank
+        self.log_dir = log_dir
+
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """Execute arbitrary post-profiling tear-down steps."""
+        self.stage = stage
+        if self.output_file:
+            self.output_file.close()
+            self.output_file = None
 
     @contextmanager
     def profile(self, action_name: str) -> None:
@@ -94,13 +108,15 @@ def describe(self) -> None:
         """Logs a profile report after the conclusion of the training run."""
         for write in self.write_streams:
             write(self.summary())
+        if self.output_file is not None:
+            self.output_file.flush()
 
     @abstractmethod
     def summary(self) -> str:
         """Create profiler summary in text format."""
 
-    def on_train_start(self, local_rank: Optional[int] = None):
-        self.local_rank = local_rank
+    def __del__(self):
+        self.teardown(None)
 
 
 class PassThroughProfiler(BaseProfiler):
@@ -110,6 +126,7 @@ class PassThroughProfiler(BaseProfiler):
     """
 
     def __init__(self):
+        self.output_file = None
         super().__init__(output_streams=None)
 
     def start(self, action_name: str) -> None:
@@ -212,19 +229,6 @@ def log_row(action, mean, total):
         output_string += os.linesep
         return output_string
 
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()
 
 
 class AdvancedProfiler(BaseProfiler):
     """
@@ -285,16 +289,3 @@ def summary(self) -> str:
         output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
 
         return output_string
-
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()
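With the stream handling hoisted into `BaseProfiler.setup`/`teardown`, the built-in profilers no longer need their own `describe`/`teardown`/`__del__` triplets; a subclass only has to implement `start`, `stop`, and `summary`. Below is a minimal hypothetical subclass written against this commit's `BaseProfiler` — the class name and timing logic are illustrative, not part of the commit:

```python
import time
from collections import defaultdict

from pytorch_lightning.profiler.profilers import BaseProfiler


class WallClockProfiler(BaseProfiler):
    """Hypothetical profiler: accumulates wall-clock time per action."""

    def __init__(self, output_filename=None):
        self.output_file = None  # the base teardown() closes this if set
        self.durations = defaultdict(float)
        self._starts = {}
        streams = None
        if output_filename is not None:
            self.output_file = open(output_filename, "w")
            streams = [self.output_file.write]
        super().__init__(output_streams=streams)

    def start(self, action_name: str) -> None:
        self._starts[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        self.durations[action_name] += time.monotonic() - self._starts.pop(action_name)

    def summary(self) -> str:
        return "\n".join(f"{name}: {secs:.4f}s" for name, secs in self.durations.items())
```

The lifecycle the trainer now drives is: `setup(stage=..., local_rank=..., log_dir=...)` before the stage runs, `describe()` to emit and flush the summary, and `teardown(stage=...)` to close the stream, with `__del__` as a last-resort fallback.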

pytorch_lightning/profiler/pytorch.py

Lines changed: 4 additions & 17 deletions

@@ -162,11 +162,11 @@ def __init__(
         self.output_fname = output_filename
         self.output_file = None
         if local_rank is not None:
-            self.on_train_start(local_rank=local_rank)
-            self.on_train_start = super().on_train_start
+            self.setup(local_rank=local_rank)
+            self.setup = super().setup
 
-    def on_train_start(self, local_rank: Optional[str] = None):
-        self.local_rank = local_rank
+    def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None):
+        super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir)
 
         # when logging to `log.info`, only perform profiling on rank 0
         if local_rank != 0 and self.output_fname is None:
@@ -290,16 +290,3 @@ def summary(self) -> str:
         output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}")
 
         return output_string
-
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()
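One subtlety preserved from the old `on_train_start` code is the self-shadowing assignment `self.setup = super().setup`: when a `local_rank` is passed to the constructor, the override runs once immediately, then an instance attribute shadows it so the trainer's later call reaches only the base-class bookkeeping. A standalone sketch of the trick (hypothetical classes, illustrating the Python mechanics only):

```python
class Base:
    def setup(self, local_rank=None):
        print(f"base bookkeeping, rank={local_rank}")


class EagerChild(Base):
    def __init__(self, local_rank=None):
        if local_rank is not None:
            # Run the rank-dependent logic eagerly...
            self.setup(local_rank=local_rank)
            # ...then shadow the override: instance attributes win over
            # class attributes, so later calls skip straight to Base.setup.
            self.setup = super().setup

    def setup(self, local_rank=None):
        super().setup(local_rank=local_rank)
        print(f"rank-dependent init, rank={local_rank}")


child = EagerChild(local_rank=0)  # prints both lines once
child.setup(local_rank=0)         # prints only "base bookkeeping, ..."
```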

pytorch_lightning/trainer/connectors/profiler_connector.py

Lines changed: 3 additions & 2 deletions

@@ -54,6 +54,7 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
         )
         self.trainer.profiler = profiler or PassThroughProfiler()
 
-    def on_train_start(self, trainer):
+    def setup(self) -> None:
+        trainer = self.trainer
         local_rank = trainer.local_rank if trainer.world_size > 1 else None
-        self.trainer.profiler.on_train_start(local_rank)
+        trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
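Spelled out, the connector now forwards three pieces of trainer state to the profiler in a single call; the values below are illustrative (a two-process fit run), not output from the commit:

```python
# Illustrative values for a 2-process fit run (not actual output):
trainer.profiler.setup(
    stage="fit",              # trainer._setup_state; TUNING maps to FITTING
    local_rank=0,             # passed only when trainer.world_size > 1
    log_dir=trainer.log_dir,  # e.g. "lightning_logs/version_0"
)
```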

pytorch_lightning/trainer/properties.py

Lines changed: 10 additions & 0 deletions

@@ -491,6 +491,16 @@ def sanity_checking(self, val: bool) -> None:
         elif self.sanity_checking:
             self._running_stage = None
 
+    @property
+    def _setup_state(self) -> TrainerState:
+        # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
+        return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
+
+    @property
+    def _teardown_state(self) -> Optional[TrainerState]:
+        if self.state.running:
+            return self._setup_state
+
 
 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)
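The two properties centralize the TUNING-to-FITTING mapping that was previously duplicated in `call_setup_hook` and `call_teardown_hook`. Behaviorally, they reduce to the sketch below (a simplified stand-in enum, not the real `TrainerState`):

```python
from enum import Enum
from typing import Optional


class TrainerState(Enum):
    FITTING = "fit"
    TUNING = "tune"
    TESTING = "test"
    FINISHED = "finished"

    @property
    def running(self) -> bool:
        # Stand-in for the real property: only in-progress states count.
        return self in (TrainerState.FITTING, TrainerState.TUNING, TrainerState.TESTING)


def setup_state(state: TrainerState) -> TrainerState:
    # trainer.tune() reuses the fit dataloaders, so TUNING maps to FITTING.
    return TrainerState.FITTING if state == TrainerState.TUNING else state


def teardown_state(state: TrainerState) -> Optional[TrainerState]:
    # Falls through to None once the trainer is no longer running, matching
    # the implicit `return None` of the real property.
    return setup_state(state) if state.running else None


assert setup_state(TrainerState.TUNING) == TrainerState.FITTING
assert teardown_state(TrainerState.FINISHED) is None
```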

pytorch_lightning/trainer/trainer.py

Lines changed: 14 additions & 14 deletions

@@ -445,13 +445,15 @@ def fit(
                 |                                 ||
         {self.dispatch}                           ||
                 |                                 ||  LIGHTNING
-        {self.accelerator.start_training} or      ||
-        {self.accelerator.start_evaluating} or    ||  FLOW
-        {self.accelerator.start_predicting}       ||
+        {self.accelerator.start_training}         ||
+        or {self.accelerator.start_evaluating}    ||
+        or {self.accelerator.start_predicting}    ||  FLOW
+                |                                 ||
+        {self.run_stage}                          ||
                 |                                 ||  DIRECTION
-        {self.run_train} or                       ||
-        {self.run_evaluation} or                  ||
-        {self.run_predict}                        ||
+        {self.run_train}                          ||
+        or {self.run_evaluation}                  ||
+        or {self.run_predict}                     ||
                 |                                 ||
              results                              \/
         This is used to guide readers to the core loops: train, test, predict.
@@ -518,6 +520,9 @@ def dispatch(self):
 
     def run_stage(self):
         results = None
+
+        self.profile_connector.setup()
+
         if self.evaluating:
             results = self.run_evaluate()
         elif self.predicting:
@@ -1060,8 +1065,7 @@ def tune(
 
     def call_setup_hook(self, model: LightningModule) -> None:
         assert self.state.running, f"TrainerState: {self.state}"
-        # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
-        state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
+        state = self._setup_state
 
         if self.datamodule is not None:
             called = getattr(self.datamodule, f'has_setup_{state}')
@@ -1072,12 +1076,8 @@ def call_setup_hook(self, model: LightningModule) -> None:
         model.setup(stage=state)
 
     def call_teardown_hook(self, model: LightningModule) -> None:
-        if self.state.running:
-            state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
-        else:
-            state = None
-
-        self.profiler.teardown()
+        state = self._teardown_state
+        self.profiler.teardown(stage=state)
         self.teardown(stage=state)
         model.teardown(stage=state)
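Because the profiler setup now lives in `run_stage` rather than in the training loop's `on_train_start`, every entry point initializes the profiler, not just `fit`. A short usage sketch against this version (using the repo's `BoringModel` test helper; `fast_dev_run` keeps it cheap):

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # test helper from this repo

model = BoringModel()
trainer = Trainer(fast_dev_run=True, profiler="simple")
trainer.fit(model)   # profiler.setup(stage='fit', ...) runs inside run_stage
trainer.test(model)  # ...and so does profiler.setup(stage='test', ...)
```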

pytorch_lightning/trainer/training_loop.py

Lines changed: 0 additions & 3 deletions

@@ -102,9 +102,6 @@ def on_train_start(self):
         # hook
         self.trainer.call_hook("on_train_start")
 
-        # provide rank to profiler
-        self.trainer.profile_connector.on_train_start(self.trainer)
-
     def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
         # clean hparams
         if hasattr(model, "hparams"):

tests/test_profiler.py

Lines changed: 10 additions & 7 deletions

@@ -22,7 +22,8 @@
 import torch
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler, PyTorchProfiler
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -323,14 +324,16 @@ def test_profiler_teardown(tmpdir, cls):
     """
     This test checks if profiler teardown method is called when trainer is exiting.
     """
+
+    class TestCallback(Callback):
+
+        def on_fit_end(self, trainer, pl_module) -> None:
+            assert trainer.profiler.output_file is not None
+
     profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt"))
 
     model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        profiler=profiler,
-    )
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
    trainer.fit(model)
 
-    assert profiler.output_file.closed
+    assert profiler.output_file is None
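The updated assertions pin down the new lifecycle: the callback checks that `output_file` is still open at `on_fit_end` (teardown must not fire mid-run), while the final assertion checks that `teardown` both closed the stream and reset `output_file` to `None` once `fit` returned.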
