6 changes: 2 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -34,8 +34,6 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
-from pytorch_lightning.utilities.cloud_io import atomic_save
-from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import (
@@ -286,7 +284,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results):
             last_path = None
             if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                atomic_save(state_dict, last_path)
+                self.checkpoint_io.save_checkpoint(state_dict, last_path)

             # todo, pass complete checkpoint as state dictionary
             self.mp_queue.put(best_model_path)
@@ -307,7 +305,7 @@ def __recover_child_process_weights(self, best_path, last_path):

         # load last weights
         if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
-            ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
+            ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage))
             self.lightning_module.load_state_dict(ckpt)

     def barrier(self, *args, **kwargs) -> None:
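Both hunks above route checkpoint I/O through the plugin's checkpoint_io attribute instead of calling the free functions atomic_save and pl_load directly. A minimal sketch of the interface those call sites assume is below; the class name SimpleCheckpointIO is hypothetical, only the two methods actually exercised by this diff are shown, and the real Lightning plugin class may differ.

from typing import Any, Callable, Dict, Optional

import torch


class SimpleCheckpointIO:
    """Hypothetical stand-in for the checkpoint_io object used above."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None:
        # Approximates the atomic_save call this diff replaces by
        # serializing the state dict with torch.save.
        torch.save(checkpoint, path)

    def load_checkpoint(
        self, path: str, map_location: Optional[Callable] = None
    ) -> Dict[str, Any]:
        # map_location=(lambda storage, loc: storage) keeps tensors on CPU,
        # matching the call in __recover_child_process_weights.
        return torch.load(path, map_location=map_location)

With this indirection, swapping in a cloud- or accelerator-aware writer only requires providing a different checkpoint_io object; the spawn-plugin code above stays unchanged.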
5 changes: 1 addition & 4 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -222,7 +222,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results):
         last_path = None
         if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            self.save(state_dict, last_path)
+            self.checkpoint_io.save_checkpoint(state_dict, last_path)

         if self.local_rank == 0:
             # todo, pass complete checkpoint as state dictionary
@@ -231,9 +231,6 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results):
             self.mp_queue.put(results)
             self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue

-    def save(self, state_dict: Dict, path: _PATH) -> None:
-        xm.save(state_dict, path)
-
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not self.is_distributed:
             return obj
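The removed TPUSpawnPlugin.save method was a thin wrapper around xm.save; after this diff, the TPU plugin's save path also goes through checkpoint_io. A hedged sketch of how the XLA-specific behavior could live behind the same interface is below; the class name SimpleXLACheckpointIO is hypothetical, and it assumes torch_xla is installed on the TPU host.

from typing import Any, Dict

# Assumption: torch_xla is available. xm.save moves XLA tensors to CPU
# and, by default, writes only from the master ordinal.
import torch_xla.core.xla_model as xm


class SimpleXLACheckpointIO:
    """Hypothetical sketch: the removed xm.save wrapper, relocated behind
    the same save_checkpoint interface used in the diff above."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None:
        xm.save(checkpoint, path)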