From 7b616b06dc8a7fa3dfbd639fe2d63392c3e36b29 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 8 Sep 2021 19:43:11 +0530 Subject: [PATCH 1/5] Add remove_checkpoint to CheckpointIO plugin to simplify ModelCheckpoint Callback --- .../callbacks/model_checkpoint.py | 20 +++++-------------- .../plugins/io/checkpoint_plugin.py | 8 ++++++++ pytorch_lightning/plugins/io/torch_plugin.py | 14 +++++++++++++ .../training_type/training_type_plugin.py | 9 +++++++++ 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index cb4ef37b76363..7485937baaa9a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -486,11 +486,6 @@ def __init_triggers( def every_n_epochs(self) -> Optional[int]: return self._every_n_epochs - def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None: - if trainer.should_rank_save_checkpoint and self._fs.exists(filepath): - self._fs.rm(filepath, recursive=True) - log.debug(f"Removed checkpoint: {filepath}") - def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None: # make paths if trainer.should_rank_save_checkpoint: @@ -673,8 +668,8 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ self._save_model(trainer, filepath) - if self.last_model_path and self.last_model_path != filepath and trainer.should_rank_save_checkpoint: - self._del_model(trainer, self.last_model_path) + if self.last_model_path and self.last_model_path != filepath: + trainer.training_type_plugin.remove_checkpoint(trainer, self.last_model_path) self.last_model_path = filepath @@ -698,13 +693,8 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) self._save_model(trainer, filepath) - if ( - self.save_top_k == 1 - and self.best_model_path - and self.best_model_path != filepath - and trainer.should_rank_save_checkpoint - ): - self._del_model(trainer, self.best_model_path) + if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath: + trainer.training_type_plugin.remove_checkpoint(trainer, self.best_model_path) self.best_model_path = filepath @@ -751,7 +741,7 @@ def _update_best_and_save( self._save_model(trainer, filepath) if del_filepath is not None and filepath != del_filepath: - self._del_model(trainer, del_filepath) + trainer.training_type_plugin.remove_checkpoint(trainer, del_filepath) def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None: """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML diff --git a/pytorch_lightning/plugins/io/checkpoint_plugin.py b/pytorch_lightning/plugins/io/checkpoint_plugin.py index 506936bc347e6..b10607f050744 100644 --- a/pytorch_lightning/plugins/io/checkpoint_plugin.py +++ b/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -52,3 +52,11 @@ def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Returns: The loaded checkpoint. """ + + @abstractmethod + def remove_checkpoint(self, path: _PATH) -> None: + """Remove checkpoint filepath from the filesystem. 
+ + Args: + path: Path to checkpoint + """ diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py index be377cb39c3da..90d5eb2f8d62b 100644 --- a/pytorch_lightning/plugins/io/torch_plugin.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import Any, Callable, Dict, Optional import pytorch_lightning as pl @@ -20,6 +21,8 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.types import _PATH +log = logging.getLogger(__name__) + class TorchCheckpointIO(CheckpointIO): """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints @@ -60,3 +63,14 @@ def load_checkpoint( raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.") return pl_load(path, map_location=map_location) + + def remove_checkpoint(self, path: _PATH) -> None: + """Remove checkpoint file from the filesystem. + + Args: + path: Path to checkpoint + """ + fs = get_filesystem(path) + if fs.exists(path): + fs.rm(path, recursive=True) + log.debug(f"Removed checkpoint: {path}") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 8d35b130eac4d..29d9944a3f33d 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -269,6 +269,15 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: if self.should_rank_save_checkpoint: return self.checkpoint_io.save_checkpoint(checkpoint, filepath) + def remove_checkpoint(self, filepath: str) -> None: + """Remove checkpoint filepath from the filesystem. + + Args: + filepath: Path to checkpoint + """ + if self.should_rank_save_checkpoint: + return self.checkpoint_io.remove_checkpoint(filepath) + @contextlib.contextmanager def model_sharded_context(self) -> Generator: """Provide hook to create modules in a distributed aware context. 
This is useful for when we'd like to From d3ed5de76689c49532dd43e94277f552b6f3f78f Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 8 Sep 2021 21:47:42 +0530 Subject: [PATCH 2/5] Update save_model in ModelCheckpoint --- .../callbacks/model_checkpoint.py | 20 ++++++------------- pytorch_lightning/plugins/io/torch_plugin.py | 3 +++ 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 7485937baaa9a..42cd078d2179d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -486,14 +486,6 @@ def __init_triggers( def every_n_epochs(self) -> Optional[int]: return self._every_n_epochs - def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None: - # make paths - if trainer.should_rank_save_checkpoint: - self._fs.makedirs(os.path.dirname(filepath), exist_ok=True) - - # delegate the saving to the trainer - trainer.save_checkpoint(filepath, self.save_weights_only) - def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Tensor] = None) -> bool: if current is None: return False @@ -666,10 +658,10 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ filepath = self._format_checkpoint_name(self.CHECKPOINT_NAME_LAST, monitor_candidates) filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}") - self._save_model(trainer, filepath) + trainer.save_checkpoint(filepath, self.save_weights_only) if self.last_model_path and self.last_model_path != filepath: - trainer.training_type_plugin.remove_checkpoint(trainer, self.last_model_path) + trainer.training_type_plugin.remove_checkpoint(self.last_model_path) self.last_model_path = filepath @@ -691,10 +683,10 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate return filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) - self._save_model(trainer, filepath) + trainer.save_checkpoint(filepath, self.save_weights_only) if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath: - trainer.training_type_plugin.remove_checkpoint(trainer, self.best_model_path) + trainer.training_type_plugin.remove_checkpoint(self.best_model_path) self.best_model_path = filepath @@ -738,10 +730,10 @@ def _update_best_and_save( f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}" f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}' ) - self._save_model(trainer, filepath) + trainer.save_checkpoint(filepath, self.save_weights_only) if del_filepath is not None and filepath != del_filepath: - trainer.training_type_plugin.remove_checkpoint(trainer, del_filepath) + trainer.training_type_plugin.remove_checkpoint(del_filepath) def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None: """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py index 90d5eb2f8d62b..4413afc5d4166 100644 --- a/pytorch_lightning/plugins/io/torch_plugin.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +import os from typing import Any, Callable, Dict, Optional import pytorch_lightning as pl @@ -29,6 +30,8 @@ class TorchCheckpointIO(CheckpointIO): respectively, common for most use cases.""" def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + fs = get_filesystem(path) + fs.makedirs(os.path.dirname(path), exist_ok=True) try: # write the checkpoint dictionary on the file atomic_save(checkpoint, path) From 0d786fcbdf63c4cc58a85c87ca05916cb1b6568f Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 9 Sep 2021 18:44:40 +0530 Subject: [PATCH 3/5] Fix tests and update changelog --- CHANGELOG.md | 6 +++++- tests/checkpointing/test_model_checkpoint.py | 14 ++++++++------ tests/plugins/test_checkpoint_io_plugin.py | 17 +++++++++++++---- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4290e4edec64f..7e3b0da856748 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -107,7 +107,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) -- Add a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221)) + +- Added a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221)) + + +- Added `remove_checkpoint` to `CheckpointIO` plugin to simplify `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373)) ### Changed diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3c70bd6daec66..5d93be192582f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -760,17 +760,18 @@ def test_default_checkpoint_behavior(tmpdir): default_root_dir=tmpdir, max_epochs=3, progress_bar_refresh_rate=0, limit_train_batches=5, limit_val_batches=5 ) - with patch.object(ModelCheckpoint, "_save_model", wraps=trainer.checkpoint_callback._save_model) as save_mock: + with patch.object(trainer, "save_checkpoint", wraps=trainer.save_checkpoint) as save_mock: trainer.fit(model) results = trainer.test() assert len(results) == 1 save_dir = tmpdir / "lightning_logs" / "version_0" / "checkpoints" + save_weights_only = trainer.checkpoint_callback.save_weights_only save_mock.assert_has_calls( [ - call(trainer, save_dir / "epoch=0-step=4.ckpt"), - call(trainer, save_dir / "epoch=1-step=9.ckpt"), - call(trainer, save_dir / "epoch=2-step=14.ckpt"), + call(save_dir / "epoch=0-step=4.ckpt", save_weights_only), + call(save_dir / "epoch=1-step=9.ckpt", save_weights_only), + call(save_dir / "epoch=2-step=14.ckpt", save_weights_only), ] ) ckpts = os.listdir(save_dir) @@ -852,7 +853,6 @@ def validation_epoch_end(self, outputs): model = CurrentModel() callback = ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir) - callback._save_model = MagicMock() trainer = Trainer( callbacks=[callback], @@ -860,10 +860,12 @@ def validation_epoch_end(self, outputs): val_check_interval=1.0, max_epochs=len(monitor), ) + trainer.save_checkpoint = MagicMock() + trainer.fit(model) # check that last one is also the best one - assert callback._save_model.call_count == len(monitor) + assert trainer.save_checkpoint.call_count == len(monitor) assert mode == "min" and callback.best_model_score == 5 or mode == "max" and 
callback.best_model_score == 8 diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index 2b0195d584de7..02d6780c4e9bc 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Any, Dict, Optional from unittest.mock import MagicMock @@ -33,6 +34,9 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: return torch.load(path) + def remove_checkpoint(self, path: _PATH) -> None: + os.remove(path) + def test_checkpoint_plugin_called(tmpdir): """Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when saving/loading.""" @@ -47,10 +51,13 @@ def test_checkpoint_plugin_called(tmpdir): default_root_dir=tmpdir, plugins=SingleDevicePlugin(device, checkpoint_io=checkpoint_plugin), callbacks=ck, - max_epochs=1, + max_epochs=2, ) trainer.fit(model) - assert checkpoint_plugin.save_checkpoint.call_count == 3 + + assert checkpoint_plugin.save_checkpoint.call_count == 5 + assert checkpoint_plugin.remove_checkpoint.call_count == 1 + trainer.test(model, ckpt_path=ck.last_model_path) checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt") @@ -63,10 +70,12 @@ def test_checkpoint_plugin_called(tmpdir): default_root_dir=tmpdir, plugins=[SingleDevicePlugin(device), checkpoint_plugin], callbacks=ck, - max_epochs=1, + max_epochs=2, ) trainer.fit(model) - assert checkpoint_plugin.save_checkpoint.call_count == 3 + + assert checkpoint_plugin.save_checkpoint.call_count == 5 + assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) checkpoint_plugin.load_checkpoint.assert_called_once() From 12e4e69e1b595e0c41ab65af9fbd1c872d32fc5d Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 9 Sep 2021 19:55:05 +0530 Subject: [PATCH 4/5] Update CHANGELOG.md Co-authored-by: ananthsub --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 732916757db14..a50ba0d552f4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -111,7 +111,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221)) -- Added `inference_mode` for evaluation and prediction ([8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813)) +- Added `inference_mode` for evaluation and prediction ([#8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813)) - Added `remove_checkpoint` to `CheckpointIO` plugin by moving the responsibility from `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373)) From fc4f6f8258f49ee5d15a2b87f7d4e79e31d1cf61 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 10 Sep 2021 14:26:38 +0530 Subject: [PATCH 5/5] Update pytorch_lightning/plugins/io/checkpoint_plugin.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/plugins/io/checkpoint_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/io/checkpoint_plugin.py b/pytorch_lightning/plugins/io/checkpoint_plugin.py index b10607f050744..0d038d45a124b 100644 --- a/pytorch_lightning/plugins/io/checkpoint_plugin.py +++ b/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -55,7 +55,7 @@ def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> @abstractmethod def remove_checkpoint(self, path: _PATH) -> None: - """Remove checkpoint filepath from the filesystem. + """Remove checkpoint file from the filesystem. Args: path: Path to checkpoint
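Usage note (not part of the patch series above): a minimal sketch of what a user-defined CheckpointIO needs to provide after this change, modeled on the CustomCheckpointIO exercised in tests/plugins/test_checkpoint_io_plugin.py in PATCH 3/5. Because remove_checkpoint is declared as an @abstractmethod on CheckpointIO, existing custom implementations must add it before they can be instantiated. Only the hook signatures and the SingleDevicePlugin wiring come from the diffs above; the class name, the torch.save body, the checkpoint directory, and the commented-out fit call are illustrative assumptions.

import os
from typing import Any, Dict, Optional

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import SingleDevicePlugin
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities.types import _PATH


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        # persist the checkpoint dict; a real plugin might write to remote storage instead
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
        return torch.load(path)

    def remove_checkpoint(self, path: _PATH) -> None:
        # the new hook: ModelCheckpoint now delegates deletion of stale checkpoints
        # here (via TrainingTypePlugin.remove_checkpoint) instead of calling
        # self._fs.rm(...) in the removed _del_model helper
        os.remove(path)


checkpoint_io = CustomCheckpointIO()
trainer = Trainer(
    plugins=SingleDevicePlugin(torch.device("cpu"), checkpoint_io=checkpoint_io),
    callbacks=ModelCheckpoint(dirpath="checkpoints"),  # illustrative dirpath
    max_epochs=2,
)
# trainer.fit(model)  # any LightningModule; superseded checkpoints are removed via remove_checkpoint

With this wiring, ModelCheckpoint no longer touches the filesystem directly (_del_model and _save_model are removed by PATCH 1/5 and 2/5); saves and deletions flow through Trainer.save_checkpoint and TrainingTypePlugin.remove_checkpoint into the CheckpointIO plugin, and the training type plugin guards both with should_rank_save_checkpoint so only the appropriate rank performs the I/O.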