Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
# requires to compute the state_dict on all processes in case Metrics are present
state_dict = self.lightning_module.state_dict()

if self.global_rank == 0 and self.mp_queue is not None:
if self.should_rank_save_checkpoint and self.mp_queue is not None:
rank_zero_warn("cleaning up ddp environment...")

# save the last weights
Expand Down
16 changes: 8 additions & 8 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,21 +215,21 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
# requires to compute the state_dict on all processes in case Metrics are present
state_dict = self.lightning_module.state_dict()

if self.mp_queue is not None:
rank_zero_warn("cleaning up tpu spawn environment...")
assert self.mp_queue is not None
rank_zero_warn("cleaning up tpu spawn environment...")

if self.local_rank == 0:
# save the last weights
last_path = None
if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
self.checkpoint_io.save_checkpoint(state_dict, last_path)

if self.local_rank == 0:
# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue
# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue

def broadcast(self, obj: object, src: int = 0) -> object:
if not self.is_distributed:
Expand Down