Skip to content

Commit b6e2fbd

Browse files
authored
Merge 72ad287 into ae4dca9
2 parents ae4dca9 + 72ad287 commit b6e2fbd

File tree

2 files changed

+2
-7
lines changed

2 files changed

+2
-7
lines changed

pytorch_lightning/accelerators/gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def on_train_start(self):
2727

2828
def on_train_end(self):
2929
# clean up memory
30+
self.model.cpu()
3031
with torch.cuda.device(self.root_device):
3132
torch.cuda.empty_cache()
3233

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,7 @@ def on_train_end(self):
148148
self.trainer.profiler.describe()
149149

150150
# give accelerators a chance to finish
151-
self.trainer.accelerator_backend.on_train_end()
152-
153-
# clear mem
154-
if self.trainer._device_type == DeviceType.GPU:
155-
model = self.trainer.get_model()
156-
model.cpu()
157-
torch.cuda.empty_cache()
151+
self.trainer.accelerator.on_train_end()
158152

159153
def check_checkpoint_callback(self, should_update, is_last=False):
160154
# TODO bake this logic into the ModelCheckpoint callback

0 commit comments

Comments
 (0)