diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 0eca72095e0e0..8f8a517d544f0 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -395,6 +395,12 @@ def detach(self):
             if isinstance(v, torch.Tensor):
                 self.__setitem__(k, v.detach())

+    def cpu(self):
+        """Move all self attributes to CPU."""
+        for k, v in self.items():
+            if isinstance(v, torch.Tensor):
+                self.__setitem__(k, v.cpu())
+
     def __repr__(self):
         self_copy = self.copy()

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 2980b037c95f7..9f8d029d9bef4 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -392,6 +392,10 @@ def cache_result(self) -> None:
             # attach capture batch_size
             Result.attach_batch_size(self._batch_size, hook_result)

+            hook_result.detach()
+            if self.trainer.move_metrics_to_cpu:
+                hook_result.cpu()
+
             self._internals[fx_name].append(
                 hook_result,
                 dataloader_idx=dataloader_idx,
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 946064660f818..6a6a3229b8061 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -93,7 +93,7 @@ def cache_logged_metrics(self) -> Union[EpochResultStore, None]:
         if self._current_stage is not None:
             self._cached_results[self._current_stage].cache_result()

-    def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
+    def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
         # logging
         self.configure_logger(logger)
         # todo: IDE is complaining, these shall be initialized in the Trainer init at leas as placeholders
@@ -101,6 +101,8 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
         self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
         self.trainer.log_every_n_steps = log_every_n_steps

+        self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
+
     @property
     def should_flush_logs(self):
         should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 2d4e2c0d9e4bd..4ef83dc7de544 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -60,6 +60,7 @@
 from pytorch_lightning.plugins.plugin_connector import PluginConnector
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator
+from pytorch_lightning.utilities.memory import recursive_detach

 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -135,6 +136,7 @@ def __init__(
         amp_level: str = 'O2',
         distributed_backend: Optional[str] = None,
         automatic_optimization: bool = True,
+        move_metrics_to_cpu: bool = False,
     ):
         r"""
         Customize every aspect of training via flags
@@ -272,6 +274,9 @@ def __init__(
                 stored in a different place than the logs written in `default_root_dir`.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
                 Defaults to `default_root_dir`.
+
+            move_metrics_to_cpu: Whether to force internal logged metrics to be moved to CPU.
+                This can save some GPU memory, but can make training slower. Use with caution.
         """
         super().__init__()

@@ -363,7 +368,12 @@ def __init__(
         self.profile_connector.on_trainer_init(profiler)

         # init logger flags
-        self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps)
+        self.logger_connector.on_trainer_init(
+            logger,
+            flush_logs_every_n_steps,
+            log_every_n_steps,
+            move_metrics_to_cpu
+        )

         # init debugging flags
         self.debugging_connector.on_init_start(
@@ -603,12 +613,11 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
                 # log step metrics
                 step_metrics = self.evaluation_loop.log_evaluation_step_metrics(batch, batch_idx)

-                if step_metrics is not None:
-                    dl_step_metrics.append(step_metrics)
+                # track epoch level outputs
+                dl_step_metrics = self.track_output_for_epoch_end(dl_step_metrics, step_metrics)

                 # track epoch level outputs
-                if output is not None:
-                    dl_outputs.append(output)
+                dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)

             self.evaluation_loop.outputs.append(dl_outputs)
             self.evaluation_loop.step_metrics.append(dl_step_metrics)
@@ -634,6 +643,19 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):

         return eval_loop_results, deprecated_eval_results

+    def track_output_for_epoch_end(self, outputs, output):
+        if output is not None:
+            if isinstance(output, Result):
+                output.detach()
+                if self.move_metrics_to_cpu:
+                    output.cpu()
+            elif isinstance(output, dict):
+                output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
+            elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu:
+                output = output.cpu()
+            outputs.append(output)
+        return outputs
+
     def run_test(self):
         # only load test dataloader for testing
         # self.reset_test_dataloader(ref_model)
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 1cf06c3709e7e..f705d82868da7 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -434,6 +434,8 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
         # track metrics without grads for epoch reduction
         training_step_output_for_epoch_end = copy(result)
         training_step_output_for_epoch_end.detach()
+        if self.trainer.move_metrics_to_cpu:
+            training_step_output_for_epoch_end.cpu()

         # what flows back into the system
         training_step_output = result
diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py
index 1d3b8d27807f0..16c0ede1e5413 100644
--- a/pytorch_lightning/utilities/memory.py
+++ b/pytorch_lightning/utilities/memory.py
@@ -17,7 +17,7 @@
 import torch


-def recursive_detach(in_dict: dict) -> dict:
+def recursive_detach(in_dict: dict, to_cpu: bool = False) -> dict:
     """Detach all tensors in `in_dict`.

     May operate recursively if some of the values in `in_dict` are dictionaries
@@ -26,6 +26,7 @@ def recursive_detach(in_dict: dict) -> dict:

     Args:
         in_dict:
+        to_cpu: Whether to move tensors to CPU

     Return:
         out_dict:
@@ -35,7 +36,11 @@ def recursive_detach(in_dict: dict) -> dict:
         if isinstance(v, dict):
             out_dict.update({k: recursive_detach(v)})
         elif callable(getattr(v, 'detach', None)):
-            out_dict.update({k: v.detach()})
+            # detach and optionally move to CPU
+            v = v.detach()
+            if to_cpu:
+                v = v.cpu()
+            out_dict.update({k: v})
         else:
             out_dict.update({k: v})
     return out_dict
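
For context, a minimal usage sketch of the flag this diff introduces. The `BoringModel` module, the random dataset, and the hyperparameters below are illustrative placeholders and not part of the change; only `move_metrics_to_cpu=True` on the `Trainer` is the new API.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    """Tiny placeholder module, used only to exercise the new flag."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        loss = self.layer(x).sum()
        # With move_metrics_to_cpu=True, the cached copy of this logged value is
        # detached and moved to CPU instead of being kept on the GPU until epoch end.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

trainer = pl.Trainer(
    max_epochs=1,
    gpus=1 if torch.cuda.is_available() else 0,
    move_metrics_to_cpu=True,  # flag added by this diff; defaults to False
)
trainer.fit(BoringModel(), train_loader)
```

The trade-off is an extra device-to-host transfer per cached metric each step, which is why the docstring recommends enabling the flag only when GPU memory is the bottleneck.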
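
The extended `recursive_detach` utility can also be called on its own. The dictionary below is a made-up example showing the new `to_cpu` argument; as written in this diff, the recursive call for nested dictionaries does not forward `to_cpu`, so only top-level tensors are moved.

```python
import torch

from pytorch_lightning.utilities.memory import recursive_detach

outputs = {
    "loss": torch.randn(1, requires_grad=True),
    "logits": torch.randn(4, 2, requires_grad=True),
}

detached = recursive_detach(outputs, to_cpu=True)

assert not detached["loss"].requires_grad        # tensors no longer track gradients
assert detached["logits"].device.type == "cpu"   # top-level tensors now live on the CPU
```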