 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -162,8 +163,11 @@ def _step(self, args):
         return output

     def barrier(self, name: Optional[str] = None):
-        if torch_distrib.is_initialized():
-            torch_distrib.barrier()
+        if self.rpc_enabled:
+            # Allow RPC to handle barrier on main RPC processes
+            self.ddp_plugin.barrier()
+        elif torch_distrib.is_initialized():
+            torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group)

     def _check_can_spawn_children(self):
         if self._has_spawned_children:
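
A minimal sketch of what the plugin-side barrier delegated to above could look like; the subclass name and body are illustrative only, while `is_main_rpc_process` and `data_parallel_group` are the plugin attributes already used in this hunk:

from typing import Optional

import torch.distributed as torch_distrib

from pytorch_lightning.plugins.rpc_plugin import RPCPlugin


class ExampleRPCPlugin(RPCPlugin):
    def barrier(self, name: Optional[str] = None):
        # Only the main RPC (data-parallel) workers join the collective, and
        # the barrier is scoped to the plugin's data-parallel group, so
        # RPC-only workers are never expected to participate.
        if self.is_main_rpc_process and torch_distrib.is_initialized():
            torch_distrib.barrier(group=self.data_parallel_group)
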
@@ -177,9 +181,11 @@ def set_world_ranks(self, process_idx):
         self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
         self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

-    def model_to_device(self, model, process_idx):
+    def init_device(self, process_idx):
         self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank]
         torch.cuda.set_device(self.trainer.root_gpu)
+
+    def model_to_device(self, model):
         model.cuda(self.trainer.root_gpu)

     def get_device_ids(self):
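
Device selection is now decoupled from moving the model, so the CUDA device can be bound early in `ddp_train` while the model is moved later. A standalone sketch of that ordering, with hypothetical helper names:

import torch


def init_device(local_rank: int) -> torch.device:
    # Bind this process to its GPU before any process-group setup or model
    # placement, so later .cuda() calls land on the right device.
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    return device


def model_to_device(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    # Moving the model is independent of device selection and can happen
    # after the distributed connection (and any RPC setup) is established.
    return model.to(device)
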
@@ -192,12 +198,12 @@ def on_train_end(self):
     def early_stopping_should_stop(self, pl_module):
         stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
         torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
-        torch_distrib.barrier()
+        self.barrier('early_stopping')
         should_stop = stop == self.trainer.world_size
         return should_stop

     def broadcast(self, obj, src=0):
-        return self.dist.broadcast(obj)
+        return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group)

     def ddp_train(self, process_idx, model):
         """
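
Both the early-stopping check and the broadcast are now scoped to the plugin's data-parallel group. A hypothetical standalone illustration of a group-scoped object broadcast with plain torch.distributed (`broadcast_object_list` is available in recent PyTorch releases); the helper name is made up:

import torch.distributed as dist


def broadcast_object(obj, group, src: int = 0):
    # Only the ranks inside `group` take part in the collective; ranks that
    # are not members (e.g. RPC-only workers) are left out entirely.
    buffer = [obj]
    dist.broadcast_object_list(buffer, src=src, group=group)
    return buffer[0]
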
@@ -226,6 +232,9 @@ def ddp_train(self, process_idx, model):
         # set warning rank
         rank_zero_only.rank = self.trainer.global_rank

+        # Initialize cuda device
+        self.init_device(process_idx)
+
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
@@ -236,6 +245,15 @@ def ddp_train(self, process_idx, model):
             self.trainer.is_slurm_managing_tasks
         )

+        if isinstance(self.ddp_plugin, RPCPlugin):
+            if not self.ddp_plugin.is_main_rpc_process:
+                self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
+                self.ddp_plugin.exit_rpc_process()
+                if self.ddp_plugin.return_after_exit_rpc_process:
+                    return
+            else:
+                self.ddp_plugin.on_main_rpc_connection(self.trainer)
+
         # call setup after the ddp process has connected
         self.trainer.call_setup_hook(model)

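
A hypothetical sketch of the hooks this branch calls on a custom RPCPlugin subclass; the method names match the calls above, the bodies are placeholders only:

from pytorch_lightning.plugins.rpc_plugin import RPCPlugin


class ExamplePipePlugin(RPCPlugin):
    def on_main_rpc_connection(self, trainer):
        # Main RPC process: set up any remote modules here, then let
        # ddp_train continue with the normal DDP path.
        ...

    def on_accelerator_exit_rpc_process(self, trainer):
        # Non-main RPC worker: it only serves RPC requests, so ddp_train may
        # return early afterwards (controlled by return_after_exit_rpc_process).
        ...
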
@@ -251,7 +269,7 @@ def ddp_train(self, process_idx, model):
             model = self.configure_sync_batchnorm(model)

         # move the model to the correct device
-        self.model_to_device(model, process_idx)
+        self.model_to_device(model)

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -284,7 +302,7 @@ def ddp_train(self, process_idx, model):
         return results

     def configure_ddp(
-        self, model: LightningModule, device_ids: List[int]
+            self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
@@ -317,3 +335,17 @@ def sync_tensor(self,

     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
+
+    @property
+    def distributed_sampler_kwargs(self):
+        distributed_sampler_kwargs = dict(
+            num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
+            rank=self.trainer.global_rank
+        )
+        if self.ddp_plugin is not None:
+            distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
+        return distributed_sampler_kwargs
+
+    @property
+    def require_distributed_sampler(self):
+        return True
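
The kwargs assembled by `distributed_sampler_kwargs` map directly onto torch's DistributedSampler. A hypothetical usage sketch, with made-up values standing in for what the property would return:

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100).float())
sampler_kwargs = dict(num_replicas=4, rank=1)  # stand-in for the property's return value
sampler = DistributedSampler(dataset, **sampler_kwargs)
loader = DataLoader(dataset, sampler=sampler, batch_size=8)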