From a3f315b089eb26a260764075aa135a404e57eaaf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 27 Nov 2021 04:31:02 +0100 Subject: [PATCH 01/14] Minor changes in preparation for saving the loops state --- pytorch_lightning/loops/base.py | 4 +- pytorch_lightning/trainer/progress.py | 3 -- tests/checkpointing/test_model_checkpoint.py | 41 ++++++++---------- tests/models/test_hooks.py | 45 ++++++-------------- tests/models/test_restore.py | 10 +++-- tests/trainer/test_dataloaders.py | 3 +- 6 files changed, 39 insertions(+), 67 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 38b0d652e5d2f..3aa84962fd7a5 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -262,7 +262,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional elif ( isinstance(v, ResultCollection) and self.trainer is not None - and getattr(self.trainer, "lightning_module", None) is not None + and self.trainer.lightning_module is not None ): metric_attributes = { name: module @@ -278,7 +278,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional # On reload, we need to re-attach the `Metric`s back to the `ResultCollection`. # The references are provided through the `metric_attributes` dictionary. v.load_state_dict( - state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce + state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce ) if not self.trainer.is_global_zero: diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 7eaf219910b67..f6b136f8d11ec 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -148,9 +148,6 @@ def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int) """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) - def reset_on_epoch(self) -> None: - self.current.reset() - def reset_on_run(self) -> None: self.current.reset() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 04255d51ad069..f29f6e9159b05 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -864,7 +864,8 @@ def validation_epoch_end(self, outputs): def test_checkpoint_repeated_strategy(tmpdir): - """This test validates that the checkpoint can be called when provided to callbacks list.""" + """This test validates checkpoint can be called several times without increasing internally its global step if + nothing run.""" checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}") class ExtendedBoringModel(BoringModel): @@ -875,34 +876,25 @@ def validation_step(self, batch, batch_idx): model = ExtendedBoringModel() model.validation_epoch_end = None - trainer = Trainer( - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - callbacks=[checkpoint_callback], - enable_progress_bar=False, - enable_model_summary=False, - ) + trainer_kwargs = { + "max_epochs": 1, + "limit_train_batches": 2, + "limit_val_batches": 2, + "limit_test_batches": 2, + "enable_progress_bar": False, + "enable_model_summary": False, + } + trainer = Trainer(**trainer_kwargs, callbacks=[checkpoint_callback]) trainer.fit(model) assert os.listdir(tmpdir) == ["epoch=00.ckpt"] for idx in range(4): # load from checkpoint - model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path) - trainer = pl.Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - enable_progress_bar=False, - enable_model_summary=False, - ) + trainer = pl.Trainer(**trainer_kwargs, default_root_dir=tmpdir) trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path) - trainer.test(model, verbose=False) - assert set(os.listdir(tmpdir)) == {"epoch=00.ckpt", "lightning_logs"} - assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f"version_{i}" for i in range(4)} + trainer.test(ckpt_path=checkpoint_callback.best_model_path, verbose=False) + assert set(os.listdir(tmpdir)) == {"epoch=00.ckpt", "lightning_logs"} + assert set(os.listdir(tmpdir / "lightning_logs")) == {f"version_{i}" for i in range(4)} def test_checkpoint_repeated_strategy_extended(tmpdir): @@ -935,7 +927,8 @@ def assert_checkpoint_log_dir(idx): lightning_logs = tmpdir / "lightning_logs" actual = [d.basename for d in lightning_logs.listdir(sort=True)] assert actual == [f"version_{i}" for i in range(idx + 1)] - assert len(ckpt_dir.listdir()) == epochs + actual = [d.basename for d in ckpt_dir.listdir()] + assert len(actual) == epochs, actual ckpt_dir = tmpdir / "checkpoints" checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index e8db816ed4edc..f49aefd370bf9 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -594,11 +594,13 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): called = [] model = HookedModel(called) callback = HookedCallback(called) + + # already performed 1 step, resume and do 2 more train_batches = 2 + steps_after_reload = 1 + train_batches trainer = Trainer( default_root_dir=tmpdir, - # already performed 1 step, now resuming to do an additional 2 - max_steps=(1 + train_batches), + max_steps=steps_after_reload, limit_val_batches=0, enable_progress_bar=False, enable_model_summary=False, @@ -609,16 +611,19 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), ] + trainer.fit(model, ckpt_path=best_model_path) - saved_ckpt = { + loaded_ckpt = { "callbacks": ANY, - "epoch": 2, # TODO: wrong saved epoch - "global_step": (1 + train_batches), + "epoch": 1, # TODO: wrong saved epoch, should be 0 + "global_step": 1, "lr_schedulers": ANY, "optimizer_states": ANY, "pytorch-lightning_version": __version__, "state_dict": ANY, } + # TODO: wrong saved epoch, should be 0 + saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2} expected = [ dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), @@ -627,20 +632,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)), dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")), dict(name="setup", kwargs=dict(stage="fit")), - dict( - name="on_load_checkpoint", - args=( - { - "callbacks": ANY, - "epoch": 1, - "global_step": 1, - "lr_schedulers": ANY, - "optimizer_states": ANY, - "pytorch-lightning_version": __version__, - "state_dict": ANY, - }, - ), - ), + dict(name="on_load_checkpoint", args=(loaded_ckpt,)), dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})), dict(name="configure_sharded_model"), dict(name="Callback.on_configure_sharded_model", args=(trainer, model)), @@ -878,20 +870,7 @@ def call(hook, fn, *args, **kwargs): dict(name="val_dataloader"), dict(name="train_dataloader"), dict(name="val_dataloader"), - dict( - name="on_save_checkpoint", - args=( - { - "callbacks": ANY, - "epoch": 1, - "global_step": 2, - "lr_schedulers": ANY, - "optimizer_states": ANY, - "pytorch-lightning_version": __version__, - "state_dict": ANY, - }, - ), - ), + dict(name="on_save_checkpoint", args=(ANY,)), dict(name="teardown", kwargs=dict(stage="fit")), ] assert called == expected diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 5e7c61163130d..1139e6fb5e8ad 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -84,21 +84,21 @@ class GenericValTestLossBoringModel(GenericParentValTestLossBoringModel[int]): class CustomClassificationModelDP(ClassificationModel): - def _step(self, batch, batch_idx): + def _step(self, batch): x, y = batch logits = self(x) return {"logits": logits, "y": y} def training_step(self, batch, batch_idx): - out = self._step(batch, batch_idx) + out = self._step(batch) loss = F.cross_entropy(out["logits"], out["y"]) return loss def validation_step(self, batch, batch_idx): - return self._step(batch, batch_idx) + return self._step(batch) def test_step(self, batch, batch_idx): - return self._step(batch, batch_idx) + return self._step(batch) def validation_step_end(self, outputs): self.log("val_acc", self.valid_acc(outputs["logits"], outputs["y"])) @@ -142,6 +142,8 @@ def configure_optimizers(self): max_epochs=1, limit_train_batches=2, limit_val_batches=2, + limit_test_batches=2, + limit_predict_batches=2, logger=False, callbacks=[checkpoint_callback], num_sanity_val_steps=0, diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index da7e0704cd4e2..24d1f17a8defc 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -899,10 +899,11 @@ def gen(self): trainer = Trainer( default_root_dir=os.getcwd(), max_epochs=2, - enable_model_summary=False, # we expect the second epoch to be skipped + enable_model_summary=False, ) trainer.fit(model, train_dataloaders=train_dataloader) assert trainer.global_step == 2 * yield_at_all + # we expect the second epoch to be skipped assert trainer.current_epoch == 1 From 0e0df602d4fab1b5e1af270a2d7cc98b05418c42 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 27 Nov 2021 04:35:53 +0100 Subject: [PATCH 02/14] Save the loop progress state by default --- pytorch_lightning/loops/base.py | 6 +++++- .../trainer/connectors/checkpoint_connector.py | 3 +-- tests/loops/test_loop_state_dict.py | 3 +++ tests/models/test_hooks.py | 2 ++ 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 3aa84962fd7a5..726dfccbd4b37 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -234,7 +234,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = destination[key] = v.state_dict() elif isinstance(v, Loop): v.state_dict(destination, key + ".") - elif isinstance(v, ResultCollection): + elif self.trainer.state._fault_tolerant_mode.is_enabled and isinstance(v, ResultCollection): # sync / unsync metrics v.sync() destination[key] = v.state_dict() @@ -257,6 +257,10 @@ def load_state_dict( def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None: for k, v in self.__dict__.items(): key = prefix + k + if not self.trainer.state._fault_tolerant_mode.is_enabled or key not in state_dict: + # no state for this object, maybe we are loading an old checkpoint + continue + if isinstance(v, BaseProgress): v.load_state_dict(state_dict[key]) elif ( diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ab0d3aa4288fa..ec5a1098e6dce 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -371,9 +371,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: "global_step": global_step, "pytorch-lightning_version": pl.__version__, "state_dict": self._get_lightning_module_state_dict(), + "loops": self._get_loops_state_dict(), } - if _fault_tolerant_training(): - checkpoint["loops"] = self._get_loops_state_dict() if not weights_only: # dump callbacks diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 72eeb197e9e57..69daefcf4ea39 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +from unittest import mock from unittest.mock import Mock import pytest @@ -37,6 +39,7 @@ def test_loops_state_dict(): assert fit_loop.state_dict() == new_fit_loop.state_dict() +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_loops_state_dict_structure(): trainer = Trainer() trainer.train_dataloader = Mock() diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index f49aefd370bf9..fc999943720e3 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -494,6 +494,7 @@ def training_step(self, batch, batch_idx): "optimizer_states": ANY, "pytorch-lightning_version": __version__, "state_dict": ANY, + "loops": ANY, } if kwargs.get("amp_backend") == "native": saved_ckpt["native_amp_scaling_state"] = ANY @@ -621,6 +622,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): "optimizer_states": ANY, "pytorch-lightning_version": __version__, "state_dict": ANY, + "loops": ANY, } # TODO: wrong saved epoch, should be 0 saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2} From 334c64f006371cea19bd0075e94bbab75473b676 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 27 Nov 2021 04:41:45 +0100 Subject: [PATCH 03/14] Update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 264e66e278b6e..6655097642bcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,7 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) -- +- Save the Loop's state and progress by default ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784)) ### Changed From 6823408103b078e51325b240ce108c47d7ef8267 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 27 Nov 2021 04:42:46 +0100 Subject: [PATCH 04/14] Update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6655097642bcb..adbd208aa1261 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,7 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) -- Save the Loop's state and progress by default ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784)) +- Save the `Loop`'s state and progress by default ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784)) ### Changed From c83be31d4ade2858b8edcbaf8a9d996294248786 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 30 Nov 2021 20:14:06 +0100 Subject: [PATCH 05/14] Fixes --- pytorch_lightning/loops/base.py | 8 +++++--- tests/trainer/test_trainer.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 726dfccbd4b37..50f8463352230 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -21,6 +21,7 @@ import pytorch_lightning as pl from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import BaseProgress +from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException T = TypeVar("T") # the output type of `run` @@ -234,7 +235,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = destination[key] = v.state_dict() elif isinstance(v, Loop): v.state_dict(destination, key + ".") - elif self.trainer.state._fault_tolerant_mode.is_enabled and isinstance(v, ResultCollection): + elif isinstance(v, ResultCollection): # sync / unsync metrics v.sync() destination[key] = v.state_dict() @@ -255,9 +256,11 @@ def load_state_dict( v.load_state_dict(state_dict.copy(), prefix + k + ".") def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None: + # do not get the mode from `self.trainer` because it might not have been attached yet + ft_enabled = not _FaultTolerantMode.detect_current_mode().is_enabled for k, v in self.__dict__.items(): key = prefix + k - if not self.trainer.state._fault_tolerant_mode.is_enabled or key not in state_dict: + if not ft_enabled or key not in state_dict: # no state for this object, maybe we are loading an old checkpoint continue @@ -287,6 +290,5 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional if not self.trainer.is_global_zero: v.reset(metrics=False) - self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e440f5f703f75..89377385b3729 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -665,7 +665,7 @@ def test_benchmark_option(tmpdir): @pytest.mark.parametrize("ckpt_path", (None, "best", "specific")) @pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2)) @pytest.mark.parametrize("fn", ("validate", "test", "predict")) -def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn): +def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn): class TestModel(BoringModel): def validation_step(self, batch, batch_idx): self.log("foo", -batch_idx) From 6e2830b5e40cda0bea29fe06af2fe7f720bfb7b7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 30 Nov 2021 20:16:04 +0100 Subject: [PATCH 06/14] whitespace --- pytorch_lightning/loops/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 50f8463352230..3817d94ad9a1d 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -290,5 +290,6 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional if not self.trainer.is_global_zero: v.reset(metrics=False) + self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True From 757b9a6b97b6bc3509e366a33ac7dd961a05e024 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 30 Nov 2021 20:41:48 +0100 Subject: [PATCH 07/14] Fix timer test --- pytorch_lightning/callbacks/timer.py | 7 +++++++ tests/callbacks/test_timer.py | 9 +++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index 810439b15bcbf..86c84d61e0ec1 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -142,6 +142,13 @@ def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._end_time[RunningStage.TESTING] = time.monotonic() + def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + # this checks the time after the state is reloaded, regardless of the interval. + # this is necessary in case we load a state whose timer is already depleted + if self._duration is None: + return + self._check_time_remaining(trainer) + def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: if self._interval != Interval.step or self._duration is None: return diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index a1a8af0642982..c7c4e0458ee12 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -168,15 +168,12 @@ def test_timer_resume_training(tmpdir): assert trainer.current_epoch < 99 saved_global_step = trainer.global_step - # resume training (with depleted timer + # resume training (with depleted timer) timer = Timer(duration=timedelta(milliseconds=200)) - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=[timer, checkpoint_callback], - ) + trainer = Trainer(default_root_dir=tmpdir, callbacks=timer) trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path) assert timer._offset > 0 - assert trainer.global_step == saved_global_step + 1 + assert trainer.global_step == saved_global_step @RunIf(skip_windows=True) From 82d7e72378ce30a6f445f5a9811a0940430db54d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 30 Nov 2021 20:44:23 +0100 Subject: [PATCH 08/14] Undo change --- tests/loops/test_loop_state_dict.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 69daefcf4ea39..72eeb197e9e57 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from unittest import mock from unittest.mock import Mock import pytest @@ -39,7 +37,6 @@ def test_loops_state_dict(): assert fit_loop.state_dict() == new_fit_loop.state_dict() -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_loops_state_dict_structure(): trainer = Trainer() trainer.train_dataloader = Mock() From 180a19f2468d1e7815d0e845b422e97b217df2b6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 30 Nov 2021 21:50:19 +0100 Subject: [PATCH 09/14] Fix test --- tests/loops/test_loops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 6338ed00e481d..3c0005a25810e 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -491,7 +491,9 @@ def configure_optimizers_multiple(self): # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the # fit loop to have an iterator, which is only available during training state_dict["epoch_loop.state_dict"]["dataloader_state_dict"] = ANY + state_dict["epoch_loop._results"] = ANY checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"]["dataloader_state_dict"] = ANY + checkpoint["loops"]["fit_loop"]["epoch_loop._results"] = ANY assert state_dict == checkpoint["loops"]["fit_loop"] trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) From 5a556f22045b66f03d8e81c581b527c75f546929 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 17 Dec 2021 03:49:28 +0100 Subject: [PATCH 10/14] Fix silly bug --- pytorch_lightning/loops/base.py | 8 ++++---- tests/loops/test_loops.py | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 45fe9538240c6..6114320bfc224 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -274,9 +274,11 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di destination[prefix + "state_dict"] = self.on_save_checkpoint() + # do not get the mode from `self.trainer` because it might not have been attached yet + ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled for k, v in self.__dict__.items(): key = prefix + k - if isinstance(v, BaseProgress): + if ft_enabled and isinstance(v, BaseProgress): destination[key] = v.state_dict() elif isinstance(v, Loop): v.state_dict(destination, key + ".") @@ -301,11 +303,9 @@ def load_state_dict( v.load_state_dict(state_dict.copy(), prefix + k + ".") def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None: - # do not get the mode from `self.trainer` because it might not have been attached yet - ft_enabled = not _FaultTolerantMode.detect_current_mode().is_enabled for k, v in self.__dict__.items(): key = prefix + k - if not ft_enabled or key not in state_dict: + if key not in state_dict: # no state for this object, maybe we are loading an old checkpoint continue diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 21411980649b7..be9989d062f21 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -213,6 +213,7 @@ def load_state_dict(self, state_dict: Dict) -> None: assert loop.outputs == list(range(10)) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_loop_hierarchy(): @dataclass class SimpleProgress(BaseProgress): @@ -535,9 +536,7 @@ def configure_optimizers_multiple(self): # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the # fit loop to have an iterator, which is only available during training state_dict["epoch_loop.state_dict"]["dataloader_state_dict"] = ANY - state_dict["epoch_loop._results"] = ANY checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"]["dataloader_state_dict"] = ANY - checkpoint["loops"]["fit_loop"]["epoch_loop._results"] = ANY assert state_dict == checkpoint["loops"]["fit_loop"] trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) From a8f229874db8c3b6b48016cc4c803975276ac1d4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 17 Dec 2021 03:50:44 +0100 Subject: [PATCH 11/14] Fix CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 01348dcdf47d5..2a2f5df5c7cf7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,7 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) -- Save the `Loop`'s state and progress by default ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784)) +- Save the `Loop`'s state by default in the checkpoint ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784)) - Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324)) From f5228fd8af3e2af56b6dd58dc501254c1d8f25a7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 17 Dec 2021 04:13:13 +0100 Subject: [PATCH 12/14] Fix test --- tests/loops/test_loop_state_dict.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index ed4f5169cb1cb..ced392f8c840b 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +from unittest import mock from unittest.mock import Mock import pytest @@ -37,6 +39,7 @@ def test_loops_state_dict(): assert fit_loop.state_dict() == new_fit_loop.state_dict() +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_loops_state_dict_structure(): trainer = Trainer() trainer.train_dataloader = Mock() From 778d844a259f19794cb2e1efd4d709a9d95e2fd4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 17 Dec 2021 04:23:56 +0100 Subject: [PATCH 13/14] Fix test --- pytorch_lightning/loops/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 6114320bfc224..d1bc1d557485f 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -337,4 +337,6 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional v.reset(metrics=False) self.on_load_checkpoint(state_dict[prefix + "state_dict"]) - self.restarting = True + + if _FaultTolerantMode.detect_current_mode().is_enabled: + self.restarting = True From 0d09d71e3d3680dec21c9bc184b41786e280b03b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 17 Dec 2021 08:02:56 -0500 Subject: [PATCH 14/14] Fix hang on rank zero only --- pytorch_lightning/core/lightning.py | 4 ++-- .../trainer/connectors/logger_connector/result.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 38010f7acf0a1..5a08d2e40fedd 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -303,7 +303,7 @@ def log( add_dataloader_idx: bool = True, batch_size: Optional[int] = None, metric_attribute: Optional[str] = None, - rank_zero_only: Optional[bool] = None, + rank_zero_only: bool = False, ) -> None: """Log a key, value pair. @@ -441,7 +441,7 @@ def log_dict( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, - rank_zero_only: Optional[bool] = None, + rank_zero_only: bool = False, ) -> None: """Log a dictionary of values at once. diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 4878099afc524..d0842ce023a67 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -616,8 +616,8 @@ def cpu(self) -> "ResultCollection": def sync(self) -> None: for result_metric in self.result_metrics: - if result_metric.is_tensor: - result_metric.sync() + if result_metric.is_tensor and not result_metric._is_synced: + result_metric.sync(should_sync=not result_metric.meta.sync.rank_zero_only) def unsync(self) -> None: for result_metric in self.result_metrics: