From f1aa9292b2acafe735f2bfc9a1792ad4ef3600e9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 15 Feb 2021 00:18:23 +0100
Subject: [PATCH 1/2] on train end

---
 pytorch_lightning/accelerators/gpu.py      | 1 +
 pytorch_lightning/trainer/training_loop.py | 8 +-------
 2 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 9ec6ad5cdee75..d1c0ee699c53b 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -29,6 +29,7 @@ def on_train_end(self):
         # clean up memory
         with torch.cuda.device(self.root_device):
             torch.cuda.empty_cache()
+        self.model.cpu()
 
     @staticmethod
     def set_nvidia_flags():
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index f727a15310a84..f7f44625a3062 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -148,13 +148,7 @@ def on_train_end(self):
         self.trainer.profiler.describe()
 
         # give accelerators a chance to finish
-        self.trainer.accelerator_backend.on_train_end()
-
-        # clear mem
-        if self.trainer._device_type == DeviceType.GPU:
-            model = self.trainer.get_model()
-            model.cpu()
-            torch.cuda.empty_cache()
+        self.trainer.accelerator.on_train_end()
 
     def check_checkpoint_callback(self, should_update, is_last=False):
         # TODO bake this logic into the ModelCheckpoint callback

From 47ec3fdd8dba763d9df62470e1ac9ed1a88b592b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 15 Feb 2021 00:24:38 +0100
Subject: [PATCH 2/2] switch order

---
 pytorch_lightning/accelerators/gpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index d1c0ee699c53b..53f9388d83597 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -27,9 +27,9 @@ def on_train_start(self):
 
     def on_train_end(self):
         # clean up memory
+        self.model.cpu()
         with torch.cuda.device(self.root_device):
             torch.cuda.empty_cache()
-        self.model.cpu()
 
     @staticmethod
     def set_nvidia_flags():
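
Note on the two patches (reviewer sketch, not part of the series): patch 1 moves the GPU cleanup from the training loop into the accelerator, and patch 2 swaps the order so `self.model.cpu()` runs before `torch.cuda.empty_cache()`. The order matters because `empty_cache()` only returns *unreferenced* cached blocks to the driver; while the weights still live on the GPU, their allocations cannot be released. Below is a minimal standalone illustration of that pattern in plain PyTorch, outside Lightning; the `train_and_release` helper and `nn.Linear` model are hypothetical stand-ins, only the `torch`/`torch.cuda` calls are real APIs.

```python
import torch
import torch.nn as nn


def train_and_release(model: nn.Module, device: torch.device) -> nn.Module:
    """Hypothetical helper mirroring the cleanup order from patch 2."""
    model = model.to(device)
    # ... training would happen here ...

    # Move parameters and buffers off the GPU *first*, so their
    # allocations become unreferenced and eligible for release.
    model.cpu()

    # Only now can empty_cache() hand those cached blocks back to the
    # driver; called before .cpu() (the patch-1 order), the weights
    # would still pin their GPU memory.
    with torch.cuda.device(device):
        torch.cuda.empty_cache()
    return model


if __name__ == "__main__" and torch.cuda.is_available():
    net = train_and_release(nn.Linear(1024, 1024), torch.device("cuda:0"))
    # Allocated bytes for the weights are gone after the release.
    print(torch.cuda.memory_allocated(0))
```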