diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 0db0dd8784a06..3bfe89be22467 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -277,7 +277,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = self.lightning_module.state_dict()
 
-        if self.global_rank == 0 and self.mp_queue is not None:
+        if self.should_rank_save_checkpoint and self.mp_queue is not None:
             rank_zero_warn("cleaning up ddp environment...")
 
             # save the last weights
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 1afb8e6c6d767..7f81bd15e012b 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -215,21 +215,21 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = self.lightning_module.state_dict()
 
-        if self.mp_queue is not None:
-            rank_zero_warn("cleaning up tpu spawn environment...")
+        assert self.mp_queue is not None
+        rank_zero_warn("cleaning up tpu spawn environment...")
 
+        if self.local_rank == 0:
             # save the last weights
             last_path = None
             if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
                 self.checkpoint_io.save_checkpoint(state_dict, last_path)
 
-            if self.local_rank == 0:
-                # todo, pass complete checkpoint as state dictionary
-                self.mp_queue.put(best_model_path)
-                self.mp_queue.put(last_path)
-                self.mp_queue.put(results)
-                self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
+            # todo, pass complete checkpoint as state dictionary
+            self.mp_queue.put(best_model_path)
+            self.mp_queue.put(last_path)
+            self.mp_queue.put(results)
+            self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not self.is_distributed: