From c3d842ebd19e49d770dbd1d96d0a599a8be7944f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 16:59:30 +0100
Subject: [PATCH 1/6] use checkpoint io plugin for saving

---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 +--
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index fad41f12302f2..2158d4fabd0ef 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -35,7 +35,6 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
-from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
@@ -290,7 +289,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
             last_path = None
             if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                atomic_save(state_dict, last_path)
+                self.checkpoint_io.save_checkpoint(state_dict, last_path)
 
             # todo, pass complete checkpoint as state dictionary
             self.mp_queue.put(best_model_path)
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 5ef8a46d7127f..a5ac08f6c2ade 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -221,7 +221,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
             last_path = None
             if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                self.save(state_dict, last_path)
+                self.checkpoint_io.save_checkpoint(state_dict, last_path)
 
             if self.local_rank == 0:
                 # todo, pass complete checkpoint as state dictionary

From 22c790a44d644d7c9358b043b7d8064437b8f8df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 17:37:35 +0100
Subject: [PATCH 2/6] share same method in spawn and tpu spawn

---
 .../plugins/training_type/ddp_spawn.py        |  2 +-
 .../plugins/training_type/tpu_spawn.py        | 23 -------------------
 2 files changed, 1 insertion(+), 24 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 2158d4fabd0ef..ecb3f05033302 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -282,7 +282,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = self.lightning_module.state_dict()
 
-        if self.global_rank == 0 and self.mp_queue is not None:
+        if self.local_rank == 0 and self.mp_queue is not None:
             rank_zero_warn("cleaning up ddp environment...")
 
             # save the last weights
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index a5ac08f6c2ade..b1133bfa1f1cb 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -207,29 +207,6 @@ def barrier(self, name: Optional[str] = None) -> None:
         if self.is_distributed:
             rendezvous(name)
 
-    def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
-        checkpoint_callback = trainer.checkpoint_callback
-        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
-
-        # requires to compute the state_dict on all processes in case Metrics are present
-        state_dict = self.lightning_module.state_dict()
-
-        if self.mp_queue is not None:
-            rank_zero_warn("cleaning up tpu spawn environment...")
-
-            # save the last weights
-            last_path = None
-            if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
-                last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                self.checkpoint_io.save_checkpoint(state_dict, last_path)
-
-            if self.local_rank == 0:
-                # todo, pass complete checkpoint as state dictionary
-                self.mp_queue.put(best_model_path)
-                self.mp_queue.put(last_path)
-                self.mp_queue.put(results)
-                self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
-
     def save(self, state_dict: Dict, path: _PATH) -> None:
         xm.save(state_dict, path)
 

From 2b74c927e94748a794629c0b09027d5fd234f6f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 21:43:33 +0100
Subject: [PATCH 3/6] remove

---
 .../plugins/training_type/tpu_spawn.py        | 23 -------------------
 1 file changed, 23 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 1afb8e6c6d767..888941f892b3b 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -208,29 +208,6 @@ def barrier(self, name: Optional[str] = None) -> None:
         if self.is_distributed:
             rendezvous(name)
 
-    def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
-        checkpoint_callback = trainer.checkpoint_callback
-        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
-
-        # requires to compute the state_dict on all processes in case Metrics are present
-        state_dict = self.lightning_module.state_dict()
-
-        if self.mp_queue is not None:
-            rank_zero_warn("cleaning up tpu spawn environment...")
-
-            # save the last weights
-            last_path = None
-            if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
-                last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                self.checkpoint_io.save_checkpoint(state_dict, last_path)
-
-            if self.local_rank == 0:
-                # todo, pass complete checkpoint as state dictionary
-                self.mp_queue.put(best_model_path)
-                self.mp_queue.put(last_path)
-                self.mp_queue.put(results)
-                self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
-
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not self.is_distributed:
             return obj

From 42e659dba09955ad6b4db79f0123573f93e3b6d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 21:53:25 +0100
Subject: [PATCH 4/6] make equivalent

---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 84832b17c0334..3bfe89be22467 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -277,7 +277,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = self.lightning_module.state_dict()
 
-        if self.local_rank == 0 and self.mp_queue is not None:
+        if self.should_rank_save_checkpoint and self.mp_queue is not None:
             rank_zero_warn("cleaning up ddp environment...")
 
             # save the last weights

From 7192b13af46d6d74b50aa0322a4b4abf7968d8db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 22:01:44 +0100
Subject: [PATCH 5/6] verify

---
 .../plugins/training_type/tpu_spawn.py        | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 888941f892b3b..81e15723da54d 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -208,6 +208,29 @@ def barrier(self, name: Optional[str] = None) -> None:
         if self.is_distributed:
             rendezvous(name)
 
+    def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
+        checkpoint_callback = trainer.checkpoint_callback
+        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
+
+        # requires to compute the state_dict on all processes in case Metrics are present
+        state_dict = self.lightning_module.state_dict()
+
+        assert self.mp_queue is not None
+        rank_zero_warn("cleaning up tpu spawn environment...")
+
+        # save the last weights
+        last_path = None
+        if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
+            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
+            self.checkpoint_io.save_checkpoint(state_dict, last_path)
+
+        if self.local_rank == 0:
+            # todo, pass complete checkpoint as state dictionary
+            self.mp_queue.put(best_model_path)
+            self.mp_queue.put(last_path)
+            self.mp_queue.put(results)
+            self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
+
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not self.is_distributed:
             return obj

From 4a8c9ffaa7f03fc69636a0278720be47440e912b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 22:43:11 +0100
Subject: [PATCH 6/6] pull rank check up

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 81e15723da54d..7f81bd15e012b 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -218,13 +218,13 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
         assert self.mp_queue is not None
         rank_zero_warn("cleaning up tpu spawn environment...")
 
-        # save the last weights
-        last_path = None
-        if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
-            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            self.checkpoint_io.save_checkpoint(state_dict, last_path)
-
         if self.local_rank == 0:
+            # save the last weights
+            last_path = None
+            if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
+                last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
+                self.checkpoint_io.save_checkpoint(state_dict, last_path)
+
             # todo, pass complete checkpoint as state dictionary
             self.mp_queue.put(best_model_path)
             self.mp_queue.put(last_path)