
Commit 8c32bf2

refactor on_gpu handling in checkpoint connector (#7860)
1 parent acd38dd commit 8c32bf2

File tree

3 files changed: +7 -7 lines changed


pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 3 additions & 3 deletions

@@ -60,7 +60,7 @@ def restore_weights(self) -> None:
 
         # 2. Attempt to restore states from `resume_from_checkpoint` file
         elif self.trainer.resume_from_checkpoint is not None:
-            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
+            self.restore(self.trainer.resume_from_checkpoint)
 
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier('TrainerIOMixin.restore_weights')
@@ -69,7 +69,7 @@ def restore_weights(self) -> None:
         if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()
 
-    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
+    def restore(self, checkpoint_path: str) -> bool:
         """
         Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
         All restored states are listed in return value description of `dump_checkpoint`.
@@ -85,7 +85,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
 
         model = self.trainer.lightning_module
 
-        if on_gpu:
+        if self.trainer._device_type == DeviceType.GPU:
            model.cuda(self.trainer.root_gpu)
 
        # restore training state
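The hunks above invert the on_gpu dependency: instead of every caller computing the flag from trainer state and threading it through as an argument, restore now reads self.trainer._device_type directly. A minimal sketch of the pattern, using hypothetical stand-ins (FakeTrainer and this trimmed CheckpointConnector are illustrative, not Lightning's real classes):

from enum import Enum


class DeviceType(Enum):
    CPU = "cpu"
    GPU = "gpu"


class FakeTrainer:
    """Hypothetical stand-in for pl.Trainer, carrying only the device state."""

    def __init__(self, device_type: DeviceType) -> None:
        self._device_type = device_type


class CheckpointConnector:
    """Trimmed sketch: shows only the signature change, not real checkpoint I/O."""

    def __init__(self, trainer: FakeTrainer) -> None:
        self.trainer = trainer

    # Before: def restore(self, checkpoint_path: str, on_gpu: bool) -> bool
    # After: the flag is derived from shared trainer state, so call sites no
    # longer need to import DeviceType or duplicate the comparison.
    def restore(self, checkpoint_path: str) -> bool:
        if self.trainer._device_type == DeviceType.GPU:
            print(f"would move the model to GPU before loading {checkpoint_path}")
        return True


connector = CheckpointConnector(FakeTrainer(DeviceType.CPU))
connector.restore("example.ckpt")  # no on_gpu argument needed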

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.loggers.base import DummyLogger
-from pytorch_lightning.utilities import DeviceType, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -83,7 +83,7 @@ def scale_batch_size(
 
     # Restore initial state of model
     if trainer.is_global_zero:
-        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
+        trainer.checkpoint_connector.restore(str(save_path))
         fs = get_filesystem(str(save_path))
         if fs.exists(save_path):
             fs.rm(save_path)

pytorch_lightning/tuner/lr_finder.py

Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loggers.base import DummyLogger
-from pytorch_lightning.utilities import DeviceType, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
@@ -259,7 +259,7 @@ def lr_find(
 
     # Reset model state
     if trainer.is_global_zero:
-        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
+        trainer.checkpoint_connector.restore(str(save_path))
         fs = get_filesystem(str(save_path))
         if fs.exists(save_path):
             fs.rm(save_path)
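Design note: with the device check folded into restore, both tuners drop their now-unused DeviceType imports, and each call site shrinks to the single-argument trainer.checkpoint_connector.restore(str(save_path)). The GPU decision now lives in exactly one place, the checkpoint connector.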
