 import torch
 import torch.multiprocessing as mp

-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp

 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.distributed.xla_multiprocessing as xmp
     from torch_xla.core.xla_model import rendezvous
-    from torch_xla.distributed.parallel_loader import ParallelLoader
+    from torch_xla.distributed.parallel_loader import MpDeviceLoader
 else:
-    xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
+    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4

 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf


 class TPUSpawnPlugin(DDPSpawnPlugin):

-    def __init__(
-        self,
-        parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: int = 1,
-        **kwargs: Dict[str, Any]
-    ) -> None:
-        super().__init__(
-            parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
-        )
+    def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
+        super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
         self.tpu_local_core_rank = 0
         self.start_method = None

@@ -61,10 +51,9 @@ def distributed_sampler_kwargs(self) -> dict:
     def is_distributed(self):
         return self.world_size != 1

-    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
+    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
         device = xm.xla_device()
-        dataloader = xla_pl.ParallelLoader(dataloader, [device])
-        dataloader = dataloader.per_device_loader(device)
+        dataloader = MpDeviceLoader(dataloader, device)
         return dataloader

     def configure_ddp(self) -> None:
@@ -104,7 +93,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

         results = trainer.train_or_test_or_predict()

-        self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

         # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
@@ -114,12 +102,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.global_rank == 0:
             time.sleep(2)

-    def __save_end_of_training_weights(self, model: LightningModule) -> None:
-        # when training ends on these platforms dump weights to get out of the main process
-        if on_colab_kaggle():
-            rank_zero_warn("cleaning up... please do not interrupt")
-            self.save_spawn_weights(model)
-
     def model_to_device(self) -> None:
         self._model.to(xm.xla_device())

@@ -159,37 +141,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj

-    def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
-        """
-        Load the temp weights saved in the process
-        To recover the trained model from the ddp process we load the saved weights
-        """
-
-        loaded_model = original_model
-
-        if self.is_global_zero:
-            # load weights saved in ddp
-            path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-            loaded_model = original_model.__class__.load_from_checkpoint(path)
-
-            # copy loaded weights to old model
-            original_model.load_state_dict(loaded_model.state_dict())
-
-            # remove ddp weights
-            os.remove(path)
-
-        return loaded_model
-
-    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
-        """
-        Dump a temporary checkpoint after ddp ends to get weights out of the process
-        """
-        if model.trainer.is_global_zero:
-            path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-            model.trainer.save_checkpoint(path)
-            return path
-
-    def reduce_decision(self, decision: bool) -> bool:
+    def reduce_boolean_decision(self, decision: bool) -> bool:
         decision = torch.tensor(int(decision), device=self.device)
         decision = self.reduce(decision, "sum")
         decision = bool(decision == self.world_size)
@@ -213,40 +165,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

         return output

-    def post_dispatch(self) -> None:
-        # TODO: Check if trainer references can be resolved otherwise
-        model = self.lightning_module
-
-        # restore main state with best weights
-        best_path = self.mp_queue.get()
-        last_path = self.mp_queue.get()
-        self._results = self.mp_queue.get()
-
-        # transfer back the best path to the trainer
-        if self.lightning_module.trainer.checkpoint_callback is not None:
-            self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
-            # todo, pass also bets score
-
-        # load last weights
-        if last_path and not self.lightning_module.trainer.testing:
-            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt)
-
-        self._model = model
-
-        # when training completes, load the weights back in main process
-        self.__load_weights_on_main_process()
-
-    def __load_weights_on_main_process(self) -> None:
-        model = self.lightning_module
-
-        # load weights if not interrupted
-        # TODO: check for trainer reference
-        if on_colab_kaggle() and not model.trainer.testing:
-            self.load_spawn_weights(model)
-
-        self._model = model
-
     def _close_logger(self, trainer) -> None:
         if trainer.logger is not None:
             trainer.logger.finalize("success")
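
For reference, a minimal sketch (not part of the diff; it assumes torch_xla is installed and a TPU device is available) contrasting the removed ParallelLoader(...).per_device_loader(...) pattern with the MpDeviceLoader wrapper this change switches to:

# Sketch only: illustrative comparison, not part of the changed file.
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=8)
device = xm.xla_device()

# Old pattern (removed): build a ParallelLoader over a list of devices,
# then explicitly request the per-device iterator.
old_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)

# New pattern (added): MpDeviceLoader wraps the DataLoader for one device
# and moves each batch to that XLA device as it is iterated.
new_loader = pl.MpDeviceLoader(loader, device)

for data, target in new_loader:
    # data and target are already on the TPU device here
    pass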