 import torch
 import torch.multiprocessing as mp

-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection

 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.distributed.xla_multiprocessing as xmp
     from torch_xla.core.xla_model import rendezvous
-    from torch_xla.distributed.parallel_loader import ParallelLoader
+    from torch_xla.distributed.parallel_loader import MpDeviceLoader
 else:
-    xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
+    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4

 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf


 class TPUSpawnPlugin(DDPSpawnPlugin):

-    def __init__(
-        self,
-        parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: int = 1,
-        **kwargs: Dict[str, Any]
-    ) -> None:
-        super().__init__(
-            parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
-        )
+    def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
+        super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
         self.tpu_local_core_rank = 0
         self.start_method = None

@@ -74,10 +64,9 @@ def distributed_sampler_kwargs(self) -> dict:
     def is_distributed(self):
         return self.world_size != 1

-    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
+    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
         device = xm.xla_device()
-        dataloader = xla_pl.ParallelLoader(dataloader, [device])
-        dataloader = dataloader.per_device_loader(device)
+        dataloader = MpDeviceLoader(dataloader, device)
         return dataloader

     def configure_ddp(self) -> None:
@@ -115,7 +104,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

         results = trainer.run_stage()

-        self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

         # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
@@ -125,12 +113,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.global_rank == 0:
             time.sleep(2)

-    def __save_end_of_training_weights(self, model: LightningModule) -> None:
-        # when training ends on these platforms dump weights to get out of the main process
-        if on_colab_kaggle():
-            rank_zero_warn("cleaning up... please do not interrupt")
-            self.save_spawn_weights(model)
-
     def model_to_device(self) -> None:
         self._model.to(xm.xla_device())

@@ -172,37 +154,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj

-    def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
-        """
-        Load the temp weights saved in the process
-        To recover the trained model from the ddp process we load the saved weights
-        """
-
-        loaded_model = original_model
-
-        if self.is_global_zero:
-            # load weights saved in ddp
-            path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-            loaded_model = original_model.__class__.load_from_checkpoint(path)
-
-            # copy loaded weights to old model
-            original_model.load_state_dict(loaded_model.state_dict())
-
-            # remove ddp weights
-            os.remove(path)
-
-        return loaded_model
-
-    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
-        """
-        Dump a temporary checkpoint after ddp ends to get weights out of the process
-        """
-        if model.trainer.is_global_zero:
-            path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-            model.trainer.save_checkpoint(path)
-            return path
-
-    def reduce_decision(self, decision: bool) -> bool:
+    def reduce_boolean_decision(self, decision: bool) -> bool:
         decision = torch.tensor(int(decision), device=self.device)
         decision = self.reduce(decision, "sum")
         decision = bool(decision == self.world_size)
@@ -226,39 +178,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

         return output

-    def post_dispatch(self) -> None:
-        # TODO: Check if trainer references can be resolved otherwise
-        model = self.lightning_module
-
-        # restore main state with best weights
-        best_path = self.mp_queue.get()
-        last_path = self.mp_queue.get()
-        self._results = self.mp_queue.get()
-
-        # transfer back the best path to the trainer
-        if self.lightning_module.trainer.checkpoint_callback is not None:
-            self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also bets score
-
-        # load last weights
-        if last_path and model.trainer.state == TrainerState.FITTING:
-            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt)
-
-        self._model = model
-
-        # when training completes, load the weights back in main process
-        self.__load_weights_on_main_process()
-
-    def __load_weights_on_main_process(self) -> None:
-        model = self.lightning_module
-
-        # load weights if not interrupted
-        if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING:
-            self.load_spawn_weights(model)
-
-        self._model = model
-
     def _close_logger(self, trainer) -> None:
         if trainer.logger is not None:
             trainer.logger.finalize("success")
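For reference, a minimal, hypothetical sketch (not part of this diff) of the MpDeviceLoader usage the plugin adopts in place of the removed ParallelLoader(...).per_device_loader(device) pattern; it assumes torch_xla is installed and an XLA/TPU device is available, and the model, optimizer, and dataset below are placeholders:

# Hypothetical standalone example of the MpDeviceLoader pattern used above.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

device = xm.xla_device()
model = torch.nn.Linear(32, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

dataset = torch.utils.data.TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=16)

# MpDeviceLoader wraps the loader once and yields batches already moved to the
# XLA device, replacing ParallelLoader(loader, [device]).per_device_loader(device).
train_loader = MpDeviceLoader(loader, device)

for x, y in train_loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    xm.optimizer_step(optimizer)  # performs the step and marks the XLA graph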