diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index b958d4808e1c2..0db0dd8784a06 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -34,8 +34,6 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
-from pytorch_lightning.utilities.cloud_io import atomic_save
-from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import (
@@ -286,7 +284,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
             last_path = None
             if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                atomic_save(state_dict, last_path)
+                self.checkpoint_io.save_checkpoint(state_dict, last_path)
 
             # todo, pass complete checkpoint as state dictionary
             self.mp_queue.put(best_model_path)
@@ -307,7 +305,7 @@ def __recover_child_process_weights(self, best_path, last_path):
 
         # load last weights
         if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
-            ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
+            ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage))
             self.lightning_module.load_state_dict(ckpt)
 
     def barrier(self, *args, **kwargs) -> None:
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index a7258fc7123c1..1afb8e6c6d767 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -222,7 +222,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
             last_path = None
             if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                self.save(state_dict, last_path)
+                self.checkpoint_io.save_checkpoint(state_dict, last_path)
 
             if self.local_rank == 0:
                 # todo, pass complete checkpoint as state dictionary
@@ -231,9 +231,6 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
                 self.mp_queue.put(results)
                 self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
 
-    def save(self, state_dict: Dict, path: _PATH) -> None:
-        xm.save(state_dict, path)
-
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not self.is_distributed:
             return obj
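
For context: every call site above now routes through the `CheckpointIO` plugin interface instead of calling `atomic_save`, `pl_load`, or `xm.save` directly. Below is a minimal sketch of an object satisfying the two methods this diff relies on. The class name is hypothetical and the body is illustrative only; the library's actual `TorchCheckpointIO` also handles fsspec/remote paths and atomic writes.

```python
from typing import Any, Callable, Dict, Optional

import torch


class NaiveTorchCheckpointIO:
    """Illustrative sketch of the two-method interface used in the diff.

    Not pytorch_lightning's real ``TorchCheckpointIO``; just enough to show
    what ``self.checkpoint_io.save_checkpoint`` / ``load_checkpoint`` expect.
    """

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None:
        # Plays the role that atomic_save() / xm.save() played before.
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str, map_location: Optional[Callable] = None) -> Dict[str, Any]:
        # Mirrors pl_load(); map_location=lambda storage, loc: storage keeps
        # tensors on CPU, matching the call site in ddp_spawn.py.
        return torch.load(path, map_location=map_location)
```

With serialization hidden behind `self.checkpoint_io`, the XLA-specific `xm.save` removed from `TPUSpawnPlugin.save` can live in an XLA-aware `CheckpointIO` implementation instead, so the spawn plugins no longer need to know how checkpoints are written to disk.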