     HorovodPlugin,
     NativeMixedPrecisionPlugin,
     PrecisionPlugin,
-    RPCPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,
     SingleTPUPlugin,
@@ -116,11 +115,11 @@ def __init__(
         self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
         self.root_gpu = device_parser.determine_root_gpu_device(self.parallel_device_ids)

-        self.handle_given_plugins(plugins)
-
         self.set_distributed_mode()
         self.configure_slurm_ddp()

+        self.handle_given_plugins(plugins)
+
         self.accelerator = self.select_accelerator()

         # override dist backend when using tpus
@@ -147,8 +146,10 @@ def __init__(
         self.replace_sampler_ddp = replace_sampler_ddp

     def handle_given_plugins(self, plugins: Optional[Sequence]):
-        if plugins is None:
-            return
+        plugins = plugins if plugins is not None else []
+
+        if isinstance(plugins, str):
+            plugins = [plugins]

         if not isinstance(plugins, Sequence):
             plugins = [plugins]
@@ -158,7 +159,10 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
         cluster_environment = None

         for plug in plugins:
-            if isinstance(plug, TrainingTypePlugin):
+            if isinstance(plug, str):
+                self.set_distributed_mode(plug)
+
+            elif isinstance(plug, TrainingTypePlugin):
                 if training_type is None:
                     training_type = plug

@@ -191,6 +195,7 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
                 )

         self._training_type_plugin = training_type
+        self._training_type_plugin = self.training_type_plugin
         self._precision_plugin = precision
         self._cluster_environment = cluster_environment or self.select_cluster_environment()

@@ -206,6 +211,7 @@ def training_type_plugin(self) -> TrainingTypePlugin:
             self._training_type_plugin = self.select_training_type_plugin()
         else:
             self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
+
         return self._training_type_plugin

     @property
@@ -327,7 +333,7 @@ def select_precision_plugin(self):

     def select_training_type_plugin(self):
         if self.use_ddp2:
-            plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self._cluster_environment)
+            plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
             use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
@@ -359,7 +365,7 @@ def select_training_type_plugin(self):
             plugin = ddp_plugin_cls(
                 parallel_devices=self.parallel_devices,
                 num_nodes=self.num_nodes,
-                cluster_environment=self.select_cluster_environment(),
+                cluster_environment=self.cluster_environment,
                 sync_batchnorm=self.sync_batchnorm,
             )
         elif self.use_dp:
@@ -425,7 +431,11 @@ def select_cluster_environment(self):
             env = TorchElasticEnvironment()
         return env

-    def set_distributed_mode(self):
+    def set_distributed_mode(self, distributed_backend: Optional[str] = None):
+
+        if distributed_backend is not None:
+            self.distributed_backend = distributed_backend
+
         if isinstance(self.distributed_backend, Accelerator):
             return

@@ -484,6 +494,9 @@ def set_distributed_mode(self):
         ):
             self.num_processes = self.num_gpus

+        if (self._device_type == DeviceType.GPU and self._distrib_type == DistributedType.DDP2):
+            self.num_processes = self.num_nodes
+
         # Horovod is an extra case...
         if self.distributed_backend == "horovod":
             self._set_horovod_backend()
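
Note on usage: after this change, handle_given_plugins also accepts plugin names given as plain strings (a single string is wrapped into a list, and string entries in the loop are routed through set_distributed_mode). A minimal sketch of what this enables, assuming the Trainer forwards its plugins argument to this connector; the "ddp_sharded" name is illustrative only:

from pytorch_lightning import Trainer

# A plugin can now be passed by name as a plain string instead of a plugin
# object; handle_given_plugins wraps the string into a list and hands the
# name to set_distributed_mode. "ddp_sharded" is an illustrative name here.
trainer = Trainer(gpus=2, plugins="ddp_sharded")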