|
16 | 16 | from pytorch_lightning import _logger as log |
17 | 17 | from pytorch_lightning.callbacks.base import Callback |
18 | 18 | from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only |
19 | | -from pytorch_lightning.utilities.cloud_io import gfile, makedirs |
| 19 | +from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path |
20 | 20 |
|
21 | 21 |
|
22 | 22 | class ModelCheckpoint(Callback): |
@@ -122,10 +122,10 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve |
122 | 122 | if gfile.isdir(filepath): |
123 | 123 | self.dirpath, self.filename = filepath, '{epoch}' |
124 | 124 | else: |
125 | | - filepath = os.path.realpath(filepath) |
| 125 | + if not is_remote_path(filepath): # don't normalize remote paths |
| 126 | + filepath = os.path.realpath(filepath) |
126 | 127 | self.dirpath, self.filename = os.path.split(filepath) |
127 | | - if not gfile.exists(self.dirpath): |
128 | | - makedirs(self.dirpath) |
| 128 | + makedirs(self.dirpath) # calls with exist_ok |
129 | 129 | self.save_last = save_last |
130 | 130 | self.save_top_k = save_top_k |
131 | 131 | self.save_weights_only = save_weights_only |
@@ -174,7 +174,12 @@ def _del_model(self, filepath): |
174 | 174 | # dependencies exist then this will work fine. |
175 | 175 | gfile.remove(filepath) |
176 | 176 | except AttributeError: |
177 | | - os.remove(filepath) |
| 177 | + if is_remote_path(filepath): |
| 178 | + log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode." |
| 179 | + " Please install tensorflow to run gfile in full mode" |
| 180 | + " if writing checkpoints to remote locations") |
| 181 | + else: |
| 182 | + os.remove(filepath) |
178 | 183 |
|
179 | 184 | def _save_model(self, filepath, trainer, pl_module): |
180 | 185 |
|
|
0 commit comments