1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -167,6 +167,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
* Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
* Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094))
* Moved result teardown to the loops ([#8245](https://github.com/PyTorchLightning/pytorch-lightning/pull/8245))


- Refactored logging
3 changes: 1 addition & 2 deletions pytorch_lightning/loops/base.py
@@ -100,7 +100,6 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
break

output = self.on_run_end()
self.teardown()
return output

@abstractmethod
@@ -132,7 +131,7 @@ def on_run_end(self) -> Any:
"""Hook to be called at the end of the run. Its return argument is returned from :attr:`run`."""

def teardown(self) -> None:
"""The very last method called inside :meth:`run`. Use to release memory etc."""
"""Use to release memory etc."""

def load_state_dict(self, state_dict: Dict) -> None:
"""Restore the loop state from the provided state_dict."""
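With `self.teardown()` removed from `run()`, releasing a loop's memory is now the caller's responsibility: the trainer triggers it after dispatch, and parent loops cascade it to their children (see the later diffs). A minimal sketch of the resulting contract, using toy stand-in classes rather than the real `Loop` base class:

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional


class MiniLoop(ABC):
    """Toy stand-in for the Lightning ``Loop`` base class."""

    @property
    @abstractmethod
    def done(self) -> bool:
        """Whether the loop has finished."""

    @abstractmethod
    def advance(self) -> None:
        """Run one iteration of the loop."""

    def run(self) -> Optional[Any]:
        # run() drives the loop but no longer calls teardown() itself
        while not self.done:
            self.advance()
        return self.on_run_end()

    def on_run_end(self) -> Any:
        """Hook called at the end of the run; its return value is returned from ``run``."""

    def teardown(self) -> None:
        """Use to release memory etc.; invoked explicitly by the loop's owner."""


class CountLoop(MiniLoop):
    def __init__(self, limit: int) -> None:
        self.limit = limit
        self.count = 0
        self.outputs: List[int] = []

    @property
    def done(self) -> bool:
        return self.count >= self.limit

    def advance(self) -> None:
        self.outputs.append(self.count)
        self.count += 1

    def on_run_end(self) -> List[int]:
        return self.outputs

    def teardown(self) -> None:
        self.outputs = []  # free tracked outputs


loop = CountLoop(limit=3)
print(loop.run())  # [0, 1, 2]; teardown has not run yet
loop.teardown()    # the owner (e.g. the trainer) releases memory when it chooses to
```

Since nothing frees state automatically at the end of `run` anymore, outputs that must be dropped right away are cleared inside `on_run_end` itself, which is what the epoch-loop diffs below do.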
4 changes: 4 additions & 0 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -263,3 +263,7 @@ def on_evaluation_epoch_end(self) -> None:
self.trainer.call_hook(hook_name)
self.trainer.call_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

def teardown(self) -> None:
self._results.cpu()
self.epoch_loop.teardown()
7 changes: 3 additions & 4 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -119,11 +119,10 @@ def advance(

def on_run_end(self) -> List[STEP_OUTPUT]:
"""Returns the outputs of the whole run"""
return self.outputs

def teardown(self) -> None:
"""Frees memory of tracked outputs"""
outputs = self.outputs
# free memory
self.outputs = []
return outputs

def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
"""The evaluation step (validation_step or test_step depending on the trainer's state).
8 changes: 4 additions & 4 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -88,12 +88,12 @@ def advance(

def on_run_end(self) -> Tuple[Any, Any]:
"""Returns the predictions and the corresponding batch indices"""
return self.predictions, self._all_batch_indices

def teardown(self) -> None:
"""Frees memory of collected predictions."""
predictions = self.predictions
all_batch_indices = self._all_batch_indices
# free memory
self.predictions = []
self._all_batch_indices = []
return predictions, all_batch_indices

def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""Runs the actual predict step together with all the
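Since `run()` no longer ends with a teardown call, the evaluation and prediction epoch loops above return their collected outputs and drop their own references inside `on_run_end` itself, instead of relying on a later `teardown()`. A toy sketch of that idiom (a stand-in class, not the real `PredictionEpochLoop`):

```python
from typing import Any, List, Tuple


class ToyPredictionEpochLoop:
    """Toy stand-in showing the on_run_end idiom used by the epoch-loop diffs above."""

    def __init__(self) -> None:
        self.predictions: List[Any] = []
        self._all_batch_indices: List[int] = []

    def advance(self, batch: Any, batch_idx: int) -> None:
        self.predictions.append(batch)
        self._all_batch_indices.append(batch_idx)

    def on_run_end(self) -> Tuple[List[Any], List[int]]:
        """Return the predictions and batch indices, then drop the loop's own references."""
        predictions = self.predictions
        all_batch_indices = self._all_batch_indices
        # free memory here: the loop no longer waits for a teardown() call to do this
        self.predictions = []
        self._all_batch_indices = []
        return predictions, all_batch_indices


loop = ToyPredictionEpochLoop()
loop.advance("batch-0", 0)
preds, indices = loop.on_run_end()
assert loop.predictions == [] and loop._all_batch_indices == []
```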
11 changes: 8 additions & 3 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -208,11 +208,16 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
self._on_train_epoch_end_hook(processed_outputs)
self.trainer.call_hook('on_epoch_end')
self.trainer.logger_connector.on_epoch_end()
return self._epoch_output

epoch_output = self._epoch_output
# free memory
self._epoch_output = None
return epoch_output

def teardown(self) -> None:
"""Frees memory of tracked epoch outputs."""
self.epoch_output = None
self._results.cpu()
self.batch_loop.teardown()
self.val_loop.teardown()

def _run_validation(self):
# reload dataloaders
8 changes: 4 additions & 4 deletions pytorch_lightning/loops/fit_loop.py
@@ -236,7 +236,7 @@ def on_advance_end(self) -> None:
self.global_step += 1

def on_run_end(self) -> None:
"""Runs teardown logic and calls the ``on_train_end`` hook"""
"""Calls the ``on_train_end`` hook"""
# NOTE: the iteration_count/current_epoch is already incremented
# Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
# To simulate that current behavior, we decrement here.
@@ -265,9 +265,6 @@ def on_run_end(self) -> None:
# give accelerators a chance to finish
self.trainer.accelerator.on_train_end()

# reset bookkeeping
self.trainer._running_stage = None

def should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated"""
return self.epoch_loop.batch_loop.should_accumulate()
@@ -291,3 +288,6 @@ def state_dict(self) -> Dict:

def load_state_dict(self, state_dict: Dict) -> None:
self.epoch_loop.load_state_dict(state_dict["epoch_loop"])

def teardown(self) -> None:
self.epoch_loop.teardown()
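The `teardown` methods added in this PR form a cascade: the fit loop delegates to its training epoch loop, which moves its logged results to CPU and tears down its batch and validation loops, and the evaluation loop does the same for its own results and epoch loop. A toy sketch of that call chain, with print statements standing in for the real work:

```python
class ToyResults:
    """Stand-in for the loops' ``_results`` collection."""

    def cpu(self) -> None:
        print("moving logged results to CPU")


class ToyChildLoop:
    def __init__(self, name: str) -> None:
        self.name = name

    def teardown(self) -> None:
        print(f"teardown: {self.name}")


class ToyEvaluationLoop:
    def __init__(self) -> None:
        self._results = ToyResults()
        self.epoch_loop = ToyChildLoop("evaluation epoch loop")

    def teardown(self) -> None:
        self._results.cpu()
        self.epoch_loop.teardown()


class ToyTrainingEpochLoop:
    def __init__(self) -> None:
        self._results = ToyResults()
        self.batch_loop = ToyChildLoop("batch loop")
        self.val_loop = ToyEvaluationLoop()

    def teardown(self) -> None:
        self._results.cpu()
        self.batch_loop.teardown()
        self.val_loop.teardown()


class ToyFitLoop:
    def __init__(self) -> None:
        self.epoch_loop = ToyTrainingEpochLoop()

    def teardown(self) -> None:
        # the fit loop simply delegates to its epoch loop, as in the diff above
        self.epoch_loop.teardown()


ToyFitLoop().teardown()
```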
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -312,9 +312,3 @@ def progress_bar_metrics(self) -> Dict[str, float]:
metrics = self.metrics[MetricSource.PBAR]
self._progress_bar_metrics.update(metrics)
return self._progress_bar_metrics

def teardown(self):
self.trainer.fit_loop.epoch_loop._results.cpu()
self.trainer.fit_loop.epoch_loop.val_loop._results.cpu()
self.trainer.validate_loop._results.cpu()
self.trainer.test_loop._results.cpu()
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -900,8 +900,10 @@ def _pre_dispatch(self):

def _post_dispatch(self):
self.accelerator.post_dispatch(self)
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.accelerator.teardown()
self.logger_connector.teardown()
self._active_loop.teardown()

def _dispatch(self):
if self.evaluating:
@@ -977,7 +979,6 @@ def _run_train(self) -> None:
self.on_keyboard_interrupt()
# same treatment as below
self.accelerator.on_train_end()
self.state.stage = None
except BaseException:
self.state.status = TrainerStatus.INTERRUPTED
if distributed_available() and self.world_size > 1:
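The trainer now drives the whole chain: `_post_dispatch` tears down the accelerator and then the active loop, replacing the removed `logger_connector.teardown()`, and these internal teardowns run before the user-facing teardown hook, as the added comment explains. A toy sketch of the ordering; how `_active_loop` resolves to the fit, validate, test, or predict loop is assumed here, not shown in the diff:

```python
class ToyAccelerator:
    def teardown(self) -> None:
        print("accelerator teardown")


class ToyActiveLoop:
    def teardown(self) -> None:
        print("active loop teardown (cascades through its child loops)")


class ToyTrainer:
    """Toy sketch of the post-dispatch ordering; not the real Trainer."""

    def __init__(self) -> None:
        self.accelerator = ToyAccelerator()
        # assumption: whichever loop matches the current stage (fit/validate/test/predict)
        self._active_loop = ToyActiveLoop()

    def _post_dispatch(self) -> None:
        # internal teardowns that must happen before the user-facing teardown hook
        self.accelerator.teardown()
        self._active_loop.teardown()


ToyTrainer()._post_dispatch()
```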