diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3f2c0ab2248d6..2600e1412902b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -167,6 +167,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
     * Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
     * Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094))
+    * Moved result teardown to the loops ([#8245](https://github.com/PyTorchLightning/pytorch-lightning/pull/8245))

 - Refactored logging

diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 1edc997e715ce..e57078fd3950d 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -100,7 +100,6 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
                 break

         output = self.on_run_end()
-        self.teardown()
         return output

     @abstractmethod
@@ -132,7 +131,7 @@ def on_run_end(self) -> Any:
         """Hook to be called at the end of the run. Its return argument is returned from :attr:`run`."""

     def teardown(self) -> None:
-        """The very last method called inside :meth:`run`. Use to release memory etc."""
+        """Use to release memory etc."""

     def load_state_dict(self, state_dict: Dict) -> None:
         """Restore the loop state from the provided state_dict."""
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index ca0118ba0d63e..02d802fb3fc15 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -263,3 +263,7 @@ def on_evaluation_epoch_end(self) -> None:
         self.trainer.call_hook(hook_name)
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
+
+    def teardown(self) -> None:
+        self._results.cpu()
+        self.epoch_loop.teardown()
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index d42a8941630a1..7f8ef06d7687f 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -119,11 +119,10 @@ def advance(

     def on_run_end(self) -> List[STEP_OUTPUT]:
         """Returns the outputs of the whole run"""
-        return self.outputs
-
-    def teardown(self) -> None:
-        """Frees memory of tracked outputs"""
+        outputs = self.outputs
+        # free memory
         self.outputs = []
+        return outputs

     def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
         """The evaluation step (validation_step or test_step depending on the trainer's state).
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index 258a81648a3e0..29a76793b4648 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -88,12 +88,12 @@ def advance(

     def on_run_end(self) -> Tuple[Any, Any]:
         """Returns the predictions and the corresponding batch indices"""
-        return self.predictions, self._all_batch_indices
-
-    def teardown(self) -> None:
-        """Frees memory of collected predictions."""
+        predictions = self.predictions
+        all_batch_indices = self._all_batch_indices
+        # free memory
         self.predictions = []
         self._all_batch_indices = []
+        return predictions, all_batch_indices

     def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """Runs the actual predict step together with all the
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 61ff4bae3d2a2..3d6fbc8dcdc1c 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -208,11 +208,16 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         self._on_train_epoch_end_hook(processed_outputs)
         self.trainer.call_hook('on_epoch_end')
         self.trainer.logger_connector.on_epoch_end()
-        return self._epoch_output
+
+        epoch_output = self._epoch_output
+        # free memory
+        self._epoch_output = None
+        return epoch_output

     def teardown(self) -> None:
-        """Frees memory of tracked epoch outputs."""
-        self.epoch_output = None
+        self._results.cpu()
+        self.batch_loop.teardown()
+        self.val_loop.teardown()

     def _run_validation(self):
         # reload dataloaders
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 224730977f67a..a7699eaec812c 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -236,7 +236,7 @@ def on_advance_end(self) -> None:
         self.global_step += 1

     def on_run_end(self) -> None:
-        """Runs teardown logic and calls the ``on_train_end`` hook"""
+        """Calls the ``on_train_end`` hook"""
         # NOTE: the iteration_count/current_epoch is already incremented
         # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
         # To simulate that current behavior, we decrement here.
@@ -265,9 +265,6 @@ def on_run_end(self) -> None:
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()

-        # reset bookkeeping
-        self.trainer._running_stage = None
-
     def should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated"""
         return self.epoch_loop.batch_loop.should_accumulate()
@@ -291,3 +288,6 @@ def state_dict(self) -> Dict:

     def load_state_dict(self, state_dict: Dict) -> None:
         self.epoch_loop.load_state_dict(state_dict["epoch_loop"])
+
+    def teardown(self) -> None:
+        self.epoch_loop.teardown()
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 4f55cd2b5c452..e248b5ff8cf13 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -312,9 +312,3 @@ def progress_bar_metrics(self) -> Dict[str, float]:
         metrics = self.metrics[MetricSource.PBAR]
         self._progress_bar_metrics.update(metrics)
         return self._progress_bar_metrics
-
-    def teardown(self):
-        self.trainer.fit_loop.epoch_loop._results.cpu()
-        self.trainer.fit_loop.epoch_loop.val_loop._results.cpu()
-        self.trainer.validate_loop._results.cpu()
-        self.trainer.test_loop._results.cpu()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b54487333e9f8..c5a95e45bbf66 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -900,8 +900,10 @@ def _pre_dispatch(self):

     def _post_dispatch(self):
         self.accelerator.post_dispatch(self)
+        # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
+        # which need to happen before.
         self.accelerator.teardown()
-        self.logger_connector.teardown()
+        self._active_loop.teardown()

     def _dispatch(self):
         if self.evaluating:
@@ -977,7 +979,6 @@ def _run_train(self) -> None:
                 self.on_keyboard_interrupt()
                 # same treatment as below
                 self.accelerator.on_train_end()
-                self.state.stage = None
         except BaseException:
             self.state.status = TrainerStatus.INTERRUPTED
             if distributed_available() and self.world_size > 1:
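
The sketch below is a minimal, self-contained illustration of the teardown ordering this patch establishes, using toy stand-in classes (ToyResults, ToyEvaluationLoop, ToyTrainingEpochLoop, ToyFitLoop, ToyTrainer are illustrative names, not pytorch_lightning imports): `Loop.run()` no longer calls `teardown()`; instead the trainer triggers it once in `_post_dispatch` via `_active_loop.teardown()`, and each loop moves its logged results to CPU and then tears down its child loops.

# Toy sketch of the teardown cascade introduced by this diff.
# Class names are hypothetical stand-ins; only the call order mirrors the patch.

class ToyResults:
    """Stand-in for a loop's `_results` collection."""

    def cpu(self) -> None:
        print("moving logged results to CPU")


class ToyEvaluationLoop:
    def __init__(self) -> None:
        self._results = ToyResults()

    def teardown(self) -> None:
        # mirrors EvaluationLoop.teardown in the diff
        self._results.cpu()


class ToyTrainingEpochLoop:
    def __init__(self) -> None:
        self._results = ToyResults()
        self.val_loop = ToyEvaluationLoop()

    def teardown(self) -> None:
        # mirrors TrainingEpochLoop.teardown: own results first, then children
        self._results.cpu()
        self.val_loop.teardown()


class ToyFitLoop:
    def __init__(self) -> None:
        self.epoch_loop = ToyTrainingEpochLoop()

    def teardown(self) -> None:
        # mirrors FitLoop.teardown: delegate to the epoch loop
        self.epoch_loop.teardown()


class ToyTrainer:
    def __init__(self) -> None:
        self._active_loop = ToyFitLoop()

    def _post_dispatch(self) -> None:
        # internal teardown runs here, before the user-facing `teardown` hook
        self._active_loop.teardown()


if __name__ == "__main__":
    ToyTrainer()._post_dispatch()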