Skip to content

Commit 60fe152

Browse files
Remove calls to profile model_forward (#12032)
Co-authored-by: ananthsub <[email protected]>
1 parent e50653d commit 60fe152

File tree

4 files changed

+38
-42
lines changed

4 files changed

+38
-42
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
567567
- Removed `log_text` and `log_image` from the `LightningLoggerBase` API ([#11857](https://github.com/PyTorchLightning/pytorch-lightning/pull/11857))
568568

569569

570+
- Removed calls to `profile("model_forward")` in favor of profiling `training_step` ([#12032](https://github.com/PyTorchLightning/pytorch-lightning/pull/12032))
571+
570572
### Fixed
571573

572574
- Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup ([#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752))

docs/source/advanced/profiler.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ PyTorch Lightning supports profiling standard actions in the training loop out o
1919
- on_train_epoch_start
2020
- on_train_epoch_end
2121
- on_train_batch_start
22-
- model_forward
2322
- model_backward
2423
- on_after_backward
2524
- optimizer_step
@@ -66,7 +65,6 @@ The profiler's results will be printed at the completion of a training ``trainer
6665
| run_training_epoch | 6.1558 | 6.1558 |
6766
| run_training_batch | 0.0022506 | 0.015754 |
6867
| [LightningModule]BoringModel.optimizer_step | 0.0017477 | 0.012234 |
69-
| model_forward | 0.00055868 | 0.0039108 |
7068
| [LightningModule]BoringModel.val_dataloader | 0.00024388 | 0.00024388 |
7169
| on_train_batch_start | 0.00014637 | 0.0010246 |
7270
| [LightningModule]BoringModel.teardown | 2.15e-06 | 2.15e-06 |

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -107,30 +107,28 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
107107
assert self.trainer is not None
108108
lightning_module = self.trainer.lightning_module
109109

110-
with self.trainer.profiler.profile("model_forward"):
110+
step_kwargs = _build_training_step_kwargs(
111+
lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
112+
)
111113

112-
step_kwargs = _build_training_step_kwargs(
113-
lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
114-
)
115-
116-
# manually capture logged metrics
117-
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
118-
self.trainer.strategy.post_training_step()
114+
# manually capture logged metrics
115+
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
116+
self.trainer.strategy.post_training_step()
119117

120-
del step_kwargs
118+
del step_kwargs
121119

122-
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
123-
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
124-
training_step_output = strategy_output if model_output is None else model_output
125-
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
120+
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
121+
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
122+
training_step_output = strategy_output if model_output is None else model_output
123+
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
126124

127-
result = self.output_result_cls.from_training_step_output(training_step_output)
125+
result = self.output_result_cls.from_training_step_output(training_step_output)
128126

129-
if self.trainer.move_metrics_to_cpu:
130-
# hiddens and the training step output are not moved as they are not considered "metrics"
131-
# the user might need them on the correct device for an operation in `training_epoch_end`
132-
assert self.trainer._results is not None
133-
self.trainer._results.cpu()
127+
if self.trainer.move_metrics_to_cpu:
128+
# hiddens and the training step output are not moved as they are not considered "metrics"
129+
# the user might need them on the correct device for an operation in `training_epoch_end`
130+
assert self.trainer._results is not None
131+
self.trainer._results.cpu()
134132

135133
self._done = True
136134
self._output = result.asdict()

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -414,32 +414,30 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
414414
# give the PL module a result for logging
415415
lightning_module = self.trainer.lightning_module
416416

417-
with self.trainer.profiler.profile("model_forward"):
418-
419-
step_kwargs = _build_training_step_kwargs(
420-
lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
421-
)
417+
step_kwargs = _build_training_step_kwargs(
418+
lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
419+
)
422420

423-
# manually capture logged metrics
424-
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
425-
self.trainer.strategy.post_training_step()
421+
# manually capture logged metrics
422+
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
423+
self.trainer.strategy.post_training_step()
426424

427-
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
428-
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
429-
training_step_output = strategy_output if model_output is None else model_output
425+
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
426+
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
427+
training_step_output = strategy_output if model_output is None else model_output
430428

431-
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
429+
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
432430

433-
result = self.output_result_cls.from_training_step_output(
434-
training_step_output, self.trainer.accumulate_grad_batches
435-
)
431+
result = self.output_result_cls.from_training_step_output(
432+
training_step_output, self.trainer.accumulate_grad_batches
433+
)
436434

437-
if self.trainer._terminate_on_nan:
438-
check_finite_loss(result.closure_loss)
435+
if self.trainer._terminate_on_nan:
436+
check_finite_loss(result.closure_loss)
439437

440-
if self.trainer.move_metrics_to_cpu:
441-
# hiddens and the training step output are not moved as they are not considered "metrics"
442-
assert self.trainer._results is not None
443-
self.trainer._results.cpu()
438+
if self.trainer.move_metrics_to_cpu:
439+
# hiddens and the training step output are not moved as they are not considered "metrics"
440+
assert self.trainer._results is not None
441+
self.trainer._results.cpu()
444442

445443
return result

0 commit comments

Comments
 (0)