@@ -4,6 +4,7 @@
 from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
+import torch.distributed as torch_distrib
 import torch.multiprocessing as mp

 from pytorch_lightning.core.lightning import LightningModule
@@ -112,7 +113,8 @@ def model_to_device(self) -> None:
         self._model.to(xm.xla_device())

     def barrier(self, name: Optional[str] = None) -> None:
-        rendezvous(f"pl.Trainer.{name}")
+        if torch_distrib.is_initialized():
+            rendezvous(f"pl.Trainer.{name}")

     def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing callback through model -> trainer -> callback?
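The hunk above makes `barrier` a no-op until a process group actually exists, instead of unconditionally calling `rendezvous`. A minimal sketch of the same guard pattern, using the stock `torch.distributed.barrier` in place of the XLA `rendezvous` so it runs without a TPU (`guarded_barrier` is an illustrative name, not part of the codebase):

```python
import torch.distributed as torch_distrib

def guarded_barrier() -> None:
    # is_initialized() stays False until init_process_group() has run
    # (e.g. in a single-process debug session), so the barrier silently
    # degrades to a no-op instead of raising.
    if torch_distrib.is_initialized():
        torch_distrib.barrier()

guarded_barrier()  # safe to call even when no process group exists
```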
@@ -126,14 +128,26 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing trainer through model -> trainer?
         if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            xm.save(self.lightning_module.state_dict(), last_path)
+            self.save(self.lightning_module.state_dict(), last_path)

         if self.global_rank == 0:
             # todo, pass complete checkpoint as state dictionary
             self.mp_queue.put(best_model_path)
             self.mp_queue.put(last_path)
             self.mp_queue.put(results)

+    def save(self, state_dict: Dict, path: str) -> None:
+        """
+        Saving with ``xm.save`` can be unstable and miss the rendezvous after ``torch.save``.
+        The rendezvous doesn't directly affect saving.
+        We can ignore the ``RuntimeError`` to reduce friction with TPUs.
+        """
+        try:
+            xm.save(state_dict, path)
+        except RuntimeError as e:
+            if "Failed to meet rendezvous" not in str(e):
+                raise e
+
     def broadcast(self, obj: object, src: int = 0) -> object:
         buffer = io.BytesIO()
         torch.save(obj, buffer)
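The new `save` wrapper tolerates the known failure mode where `xm.save` has already written the checkpoint but the trailing rendezvous fails. A standalone sketch of the same pattern, with the XLA saver replaced by an injectable callable so the control flow can be exercised anywhere (`save_ignoring_rendezvous` and `flaky_save` are illustrative names, not part of the codebase):

```python
from typing import Callable, Dict

def save_ignoring_rendezvous(save_fn: Callable, state_dict: Dict, path: str) -> None:
    try:
        save_fn(state_dict, path)
    except RuntimeError as e:
        # xm.save() writes the file before the rendezvous step, so a failed
        # rendezvous does not invalidate the checkpoint already on disk;
        # any other RuntimeError is still fatal and re-raised.
        if "Failed to meet rendezvous" not in str(e):
            raise

def flaky_save(state_dict: Dict, path: str) -> None:
    # Simulates the failure mode the wrapper is meant to absorb.
    raise RuntimeError("Failed to meet rendezvous")

save_ignoring_rendezvous(flaky_save, {"weight": 1.0}, "model.ckpt")  # swallowed
```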
@@ -281,4 +295,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
         # dump states as a checkpoint dictionary object
         _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)