
Commit 56396ab

fix checkpointing to remote file paths (#2925)
1 parent d13e5c9 commit 56396ab

5 files changed: +37 / -15 lines

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 10 additions & 5 deletions
@@ -16,7 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path


 class ModelCheckpoint(Callback):
@@ -122,10 +122,10 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
             if gfile.isdir(filepath):
                 self.dirpath, self.filename = filepath, '{epoch}'
             else:
-                filepath = os.path.realpath(filepath)
+                if not is_remote_path(filepath):  # dont normalize remote paths
+                    filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            if not gfile.exists(self.dirpath):
-                makedirs(self.dirpath)
+            makedirs(self.dirpath)  # calls with exist_ok
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -174,7 +174,12 @@ def _del_model(self, filepath):
                 # dependencies exist then this will work fine.
                 gfile.remove(filepath)
             except AttributeError:
-                os.remove(filepath)
+                if is_remote_path(filepath):
+                    log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
+                                " Please install tensorflow to run gfile in full mode"
+                                " if writing checkpoints to remote locations")
+                else:
+                    os.remove(filepath)

     def _save_model(self, filepath, trainer, pl_module):
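The key behavior change here: a remote checkpoint path such as `s3://bucket/ckpts/{epoch}` is no longer run through `os.path.realpath` (which would rewrite the URI against the working directory), and directory creation always goes through `makedirs`, which already behaves like `exist_ok=True`. A minimal sketch of the resulting path handling, assuming the `"://"` heuristic from `cloud_io.is_remote_path`; `resolve_checkpoint_path` is a hypothetical helper for illustration, not part of the library:

```python
import os


def is_remote_path(path) -> bool:
    # same "://" heuristic as pytorch_lightning.utilities.cloud_io.is_remote_path
    return "://" in str(path)


def resolve_checkpoint_path(filepath: str):
    """Split a checkpoint filepath into (dirpath, filename), skipping
    realpath normalization for remote URIs (hypothetical helper)."""
    if not is_remote_path(filepath):  # realpath would mangle s3://-style URIs
        filepath = os.path.realpath(filepath)
    return os.path.split(filepath)


print(resolve_checkpoint_path("s3://my-bucket/ckpts/{epoch}"))
# -> ('s3://my-bucket/ckpts', '{epoch}')
print(resolve_checkpoint_path("checkpoints/{epoch}"))
# -> ('<absolute cwd>/checkpoints', '{epoch}')
```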

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 7 additions & 2 deletions
@@ -127,6 +127,7 @@ def train_fx(trial_hparams, cluster_manager, _):

 """

+import io
 import os
 import re
 from abc import ABC, abstractmethod
@@ -146,6 +147,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.cloud_io import cloud_open


 try:
@@ -435,10 +437,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
             # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
             # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
             # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
+            bytesbuffer = io.BytesIO()
             if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
-                torch.save(model.state_dict(), last_path, _use_new_zipfile_serialization=False)
+                torch.save(model.state_dict(), bytesbuffer, _use_new_zipfile_serialization=False)
             else:
-                torch.save(model.state_dict(), last_path)
+                torch.save(model.state_dict(), bytesbuffer)
+            with cloud_open(last_path, 'wb') as f:
+                f.write(bytesbuffer.getvalue())
             mp_queue.put(last_path)

     def save_spawn_weights(self, model):
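Both this hunk and the `_atomic_save` change in `training_io.py` below switch to the same pattern: serialize with `torch.save` into an in-memory `io.BytesIO` buffer, then stream the bytes to the destination through a filesystem-aware open call, so one code path covers local files and remote URIs alike. A minimal, self-contained sketch of the pattern; the `opener` parameter is an illustrative stand-in for Lightning's `cloud_open` (or e.g. `fsspec.open`), not part of the library API:

```python
import io

import torch


def save_via_buffer(obj, path, opener=open):
    """Serialize ``obj`` in memory with torch.save, then write the raw bytes
    through ``opener`` -- the built-in open for local paths, or a cloud-aware
    open (e.g. Lightning's cloud_open or fsspec.open) for remote URIs."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)           # torch.save accepts any file-like object
    with opener(path, "wb") as f:
        f.write(buffer.getvalue())    # single write of the fully serialized blob


# local destination: the built-in open is enough
save_via_buffer({"epoch": 3, "state": [1, 2, 3]}, "last.ckpt")

# remote destination (assumes an fsspec-style opener is available):
# import fsspec
# save_via_buffer(model.state_dict(), "s3://bucket/last.ckpt", opener=fsspec.open)
```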

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 2 deletions
@@ -53,6 +53,7 @@
 from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.cloud_io import is_remote_path

 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -880,7 +881,7 @@ def default_root_dir(self) -> str:
         The default location to save artifacts of loggers, checkpoints etc.
         It is used as a fallback if logger or checkpoint callback do not define specific save paths.
         """
-        if "://" in str(self._default_root_dir):
+        if is_remote_path(self._default_root_dir):
             # it is a remote uri, use as is
             return self._default_root_dir
         return os.path.normpath(self._default_root_dir)
@@ -891,7 +892,7 @@ def weights_save_path(self) -> str:
         The default root location to save weights (checkpoints), e.g., when the
         :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
         """
-        if "://" in str(self._weights_save_path):
+        if is_remote_path(self._weights_save_path):
             # it is a remote uri, use as is
             return self._weights_save_path
         return os.path.normpath(self._weights_save_path)
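The effect on both properties is that a remote URI is returned untouched while a local path is still normalized. A small illustrative sketch; the helper name `_resolve_root_dir` is hypothetical, the logic mirrors the property bodies above:

```python
import os


def _resolve_root_dir(root_dir) -> str:
    # mirrors default_root_dir / weights_save_path: keep remote URIs as-is
    if "://" in str(root_dir):        # is_remote_path heuristic
        return root_dir
    return os.path.normpath(root_dir)


print(_resolve_root_dir("s3://bucket/lightning_logs"))  # s3://bucket/lightning_logs
print(_resolve_root_dir("./lightning_logs/"))           # lightning_logs
```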

pytorch_lightning/trainer/training_io.py

Lines changed: 7 additions & 5 deletions
@@ -83,6 +83,7 @@

 """

+import io
 import os
 import re
 import signal
@@ -104,7 +105,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn, AMPType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs
+from pytorch_lightning.utilities.cloud_io import cloud_open, gfile, makedirs

 try:
     import torch_xla
@@ -269,15 +270,16 @@ def _atomic_save(self, checkpoint, filepath: str):
             filepath: The path to which the checkpoint will be saved.
                 This points to the file that the checkpoint will be stored in.
         """
-        tmp_path = str(filepath) + ".part"
+        bytesbuffer = io.BytesIO()
         # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
         # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
        # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
         if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
-            torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
+            torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
         else:
-            torch.save(checkpoint, tmp_path)
-        os.replace(tmp_path, filepath)
+            torch.save(checkpoint, bytesbuffer)
+        with cloud_open(filepath, 'wb') as f:
+            f.write(bytesbuffer.getvalue())

     def save_checkpoint(self, filepath, weights_only: bool = False):
         checkpoint = self.dump_checkpoint(weights_only)
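Note the design change: the old `_atomic_save` wrote to a local `<filepath>.part` file and then atomically renamed it with `os.replace`; that rename step is dropped here, presumably because `os.replace` cannot target remote object-store URIs. The write now follows the same buffer-then-stream pattern sketched after the `distrib_data_parallel.py` diff above.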

pytorch_lightning/utilities/cloud_io.py

Lines changed: 10 additions & 1 deletion
@@ -28,6 +28,14 @@ def load(path_or_url: str, map_location=None):
     return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


+def is_remote_path(path: pathlike):
+    """Determine if a path is a local path or a remote path like s3://bucket/path
+
+    This should catch paths like s3:// hdfs:// and gcs://
+    """
+    return "://" in str(path)
+
+
 def modern_gfile():
     """Check the version number of tensorboard.

@@ -61,6 +69,7 @@ def cloud_open(path: pathlike, mode: str, newline: str = None):

 def makedirs(path: pathlike):
     if hasattr(gfile, "makedirs") and modern_gfile():
-        return gfile.makedirs(str(path))
+        if not gfile.exists(str(path)):
+            return gfile.makedirs(str(path))
     # otherwise minimal dependencies are installed and only local files will work
     return os.makedirs(path, exist_ok=True)
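Illustrative behavior of the new helper, based solely on its `"://"` substring check (the `Path` example works because the value is passed through `str()` first):

```python
from pathlib import Path

from pytorch_lightning.utilities.cloud_io import is_remote_path

assert is_remote_path("s3://bucket/checkpoints")
assert is_remote_path("gcs://bucket/checkpoints")
assert is_remote_path("hdfs://namenode/checkpoints")
assert not is_remote_path("/tmp/checkpoints")
assert not is_remote_path(Path("lightning_logs"))  # Path objects are str()-ed first
```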
