Commit 0e19d16

Move result teardown to loops (#8245)

* Move result teardown to loops
* Update CHANGELOG
* Remove teardown from run
* Move previous teardown to on_run_end
* Add comment
* Merge 8250
* Remove stage set to None where it shouldn't

1 parent f3e74ab commit 0e19d16
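For orientation, here is a minimal, framework-free sketch of the flow this commit moves to: each loop now owns its own `teardown()` and delegates to its child loops, and the trainer tears down the active loop in `_post_dispatch` instead of going through the logger connector. The class and attribute names are taken from the diffs below; the bodies are simplified stand-ins, not the actual implementations.

class Loop:
    def run(self):
        # after this commit, run() no longer calls teardown();
        # whoever owns the loop tears it down explicitly
        return self.on_run_end()

    def on_run_end(self):
        return None

    def teardown(self):
        """Use to release memory etc."""


class TrainingEpochLoop(Loop):
    def teardown(self):
        self._results.cpu()  # move logged results off the accelerator
        self.batch_loop.teardown()
        self.val_loop.teardown()


class FitLoop(Loop):
    def teardown(self):
        self.epoch_loop.teardown()


class Trainer:
    def _post_dispatch(self):
        self.accelerator.post_dispatch(self)
        # internal teardowns run before the user-facing teardown hook
        self.accelerator.teardown()
        self._active_loop.teardown()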

9 files changed, +28 -25 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -170,6 +170,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
 * Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
 * Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094))
+* Moved result teardown to the loops ([#8245](https://github.com/PyTorchLightning/pytorch-lightning/pull/8245))


 - Refactored logging

pytorch_lightning/loops/base.py

Lines changed: 1 addition & 2 deletions

@@ -114,7 +114,6 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
                 break

         output = self.on_run_end()
-        self.teardown()
         return output

     def restore(self) -> None:
@@ -149,7 +148,7 @@ def on_run_end(self) -> Any:
         """Hook to be called at the end of the run. Its return argument is returned from :attr:`run`."""

     def teardown(self) -> None:
-        """The very last method called inside :meth:`run`. Use to release memory etc."""
+        """Use to release memory etc."""

     def load_state_dict(self, state_dict: Dict) -> None:
         """Restore the loop state from the provided state_dict."""

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 4 additions & 0 deletions

@@ -263,3 +263,7 @@ def on_evaluation_epoch_end(self) -> None:
         self.trainer.call_hook(hook_name)
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
+
+    def teardown(self) -> None:
+        self._results.cpu()
+        self.epoch_loop.teardown()
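`_results` here is the loop's collection of logged results; calling `.cpu()` on it at teardown moves logged tensors off the accelerator so device memory can be reclaimed. Below is a hypothetical, self-contained illustration of that idea: only the `cpu()` name comes from the diff, the rest is an assumption for illustration.

import torch


class _ResultsSketch:
    """Hypothetical stand-in for a loop's logged-results collection."""

    def __init__(self):
        self._metrics = {}

    def log(self, name, value):
        self._metrics[name] = value

    def cpu(self):
        # move every logged tensor to CPU so accelerator memory can be freed
        self._metrics = {name: t.cpu() for name, t in self._metrics.items()}


results = _ResultsSketch()
results.log("val_loss", torch.tensor(0.25))
results.cpu()  # after teardown, nothing here keeps device memory alive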

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 3 additions & 4 deletions

@@ -119,11 +119,10 @@ def advance(

     def on_run_end(self) -> List[STEP_OUTPUT]:
         """Returns the outputs of the whole run"""
-        return self.outputs
-
-    def teardown(self) -> None:
-        """Frees memory of tracked outputs"""
+        outputs = self.outputs
+        # free memory
         self.outputs = []
+        return outputs

     def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
         """The evaluation step (validation_step or test_step depending on the trainer's state.

pytorch_lightning/loops/epoch/prediction_epoch_loop.py

Lines changed: 4 additions & 4 deletions

@@ -88,12 +88,12 @@ def advance(

     def on_run_end(self) -> Tuple[Any, Any]:
         """Returns the predictions and the corresponding batch indices"""
-        return self.predictions, self._all_batch_indices
-
-    def teardown(self) -> None:
-        """Frees memory of collected predictions."""
+        predictions = self.predictions
+        all_batch_indices = self._all_batch_indices
+        # free memory
         self.predictions = []
         self._all_batch_indices = []
+        return predictions, all_batch_indices

     def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """Runs the actual predict step together with all the

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 8 additions & 3 deletions

@@ -208,11 +208,16 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         self._on_train_epoch_end_hook(processed_outputs)
         self.trainer.call_hook('on_epoch_end')
         self.trainer.logger_connector.on_epoch_end()
-        return self._epoch_output
+
+        epoch_output = self._epoch_output
+        # free memory
+        self._epoch_output = None
+        return epoch_output

     def teardown(self) -> None:
-        """Frees memory of tracked epoch outputs."""
-        self.epoch_output = None
+        self._results.cpu()
+        self.batch_loop.teardown()
+        self.val_loop.teardown()

     def _run_validation(self):
         # reload dataloaders

pytorch_lightning/loops/fit_loop.py

Lines changed: 4 additions & 4 deletions

@@ -236,7 +236,7 @@ def on_advance_end(self) -> None:
         self.global_step += 1

     def on_run_end(self) -> None:
-        """Runs teardown logic and calls the ``on_train_end`` hook"""
+        """Calls the ``on_train_end`` hook"""
         # NOTE: the iteration_count/current_epoch is already incremented
         # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
         # To simulate that current behavior, we decrement here.
@@ -265,9 +265,6 @@ def on_run_end(self) -> None:
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()

-        # reset bookkeeping
-        self.trainer._running_stage = None
-
     def should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated"""
         return self.epoch_loop.batch_loop.should_accumulate()
@@ -291,3 +288,6 @@ def state_dict(self) -> Dict:

     def load_state_dict(self, state_dict: Dict) -> None:
         self.epoch_loop.load_state_dict(state_dict["epoch_loop"])
+
+    def teardown(self) -> None:
+        self.epoch_loop.teardown()

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 0 additions & 6 deletions

@@ -312,9 +312,3 @@ def progress_bar_metrics(self) -> Dict[str, float]:
         metrics = self.metrics[MetricSource.PBAR]
         self._progress_bar_metrics.update(metrics)
         return self._progress_bar_metrics
-
-    def teardown(self):
-        self.trainer.fit_loop.epoch_loop._results.cpu()
-        self.trainer.fit_loop.epoch_loop.val_loop._results.cpu()
-        self.trainer.validate_loop._results.cpu()
-        self.trainer.test_loop._results.cpu()

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 2 deletions

@@ -900,8 +900,10 @@ def _pre_dispatch(self):

     def _post_dispatch(self):
         self.accelerator.post_dispatch(self)
+        # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
+        # which need to happen before.
         self.accelerator.teardown()
-        self.logger_connector.teardown()
+        self._active_loop.teardown()

     def _dispatch(self):
         if self.evaluating:
@@ -977,7 +979,6 @@ def _run_train(self) -> None:
             self.on_keyboard_interrupt()
             # same treatment as below
             self.accelerator.on_train_end()
-            self.state.stage = None
         except BaseException:
             self.state.status = TrainerStatus.INTERRUPTED
             if distributed_available() and self.world_size > 1:
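The diff above calls `self._active_loop.teardown()` but does not show how `_active_loop` is resolved. As a rough sketch only (the selection logic here is an assumption, not part of this commit), the idea is to pick the loop that matches what the trainer is currently doing and tear down just that one:

class _TrainerSketch:
    """Hypothetical illustration; not the real Trainer._active_loop property."""

    def __init__(self, fit_loop, validate_loop, test_loop, predict_loop):
        self.fit_loop = fit_loop
        self.validate_loop = validate_loop
        self.test_loop = test_loop
        self.predict_loop = predict_loop
        self.training = self.validating = self.testing = self.predicting = False

    @property
    def _active_loop(self):
        # assumed mapping from trainer state to the loop that ran
        if self.training:
            return self.fit_loop
        if self.validating:
            return self.validate_loop
        if self.testing:
            return self.test_loop
        if self.predicting:
            return self.predict_loop
        return None

    def _post_dispatch_teardown(self):
        # hypothetical helper: tear down only the loop that was active
        loop = self._active_loop
        if loop is not None:
            loop.teardown()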
