 # limitations under the License.
 
 import os
+from typing import Optional, Sequence
 
 import torch
 
     DataParallelPlugin,
     DDP2Plugin,
     DDPPlugin,
+    DDPShardedPlugin,
     DDPSpawnPlugin,
+    DDPSpawnShardedPlugin,
     HorovodPlugin,
     NativeMixedPrecisionPlugin,
     PrecisionPlugin,
+    RPCPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,
     SingleTPUPlugin,
     TPUHalfPrecisionPlugin,
-    TPUSpawnPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin,
+    TPUSpawnPlugin,
+    TrainingTypePlugin,
 )
 from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
@@ -74,6 +81,7 @@ def __init__(
         amp_type,
         amp_level,
         cluster_environment,
+        plugins,
     ):
         # initialization
         self._device_type = DeviceType.CPU
@@ -95,6 +103,11 @@ def __init__(
         self.cluster_environment = cluster_environment
         self.is_slurm_managing_tasks = False
 
+        self._precision_plugin: Optional[PrecisionPlugin] = None
+        self._training_type_plugin: Optional[TrainingTypePlugin] = None
+
+        self.handle_given_plugins(plugins)
+
         # init the default rank if exists
         # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
         # this way we only show it on rank 0
@@ -136,6 +149,56 @@ def __init__(
 
         self.replace_sampler_ddp = replace_sampler_ddp
 
+    def handle_given_plugins(self, plugins: Optional[Sequence]):
+        if plugins is None:
+            return
+
+        if not isinstance(plugins, Sequence):
+            plugins = [plugins]
+
+        training_type = None
+        precision = None
+
+        for plug in plugins:
+            if isinstance(plug, TrainingTypePlugin):
+                if training_type is None:
+                    training_type = plug
+                else:
+                    raise MisconfigurationException(
+                        'You can only specify one precision and one training type plugin. '
+                        'Found more than 1 training type plugin'
+                    )
+            elif isinstance(plug, PrecisionPlugin):
+                if precision is None:
+                    precision = plug
+                else:
+                    raise MisconfigurationException(
+                        'You can only specify one precision and one training type plugin. '
+                        'Found more than 1 precision plugin'
+                    )
+            else:
+                raise MisconfigurationException(
+                    f'Found invalid type for plugin {plug}. '
+                    'Expected a precision or training type plugin.'
+                )
+
+        self._training_type_plugin = training_type
+        self._precision_plugin = precision
+
+    @property
+    def precision_plugin(self) -> PrecisionPlugin:
+        if self._precision_plugin is None:
+            self._precision_plugin = self.select_precision_plugin()
+
+        return self._precision_plugin
+
+    @property
+    def training_type_plugin(self) -> TrainingTypePlugin:
+        if self._training_type_plugin is None:
+            self._training_type_plugin = self.select_training_type_plugin()
+
+        return self._training_type_plugin
+
     @property
     def on_cpu(self):
         return self._device_type == DeviceType.CPU
@@ -205,6 +268,9 @@ def select_precision_plugin(self):
         if self.on_tpu:
             return TPUHalfPrecisionPlugin()
 
+        if isinstance(self.training_type_plugin, RPCPlugin):
+            raise MisconfigurationException
+
         if self.amp_type == "native":
             if not _NATIVE_AMP_AVAILABLE:
                 rank_zero_warn(
@@ -215,7 +281,7 @@ def select_precision_plugin(self):
                 self.amp_type = "apex"
             else:
                 log.info("Using native 16bit precision.")
-                if self.distributed_backend == "ddp_sharded" or self.distributed_backend == "ddp_sharded_spawn":
+                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                     return ShardedNativeMixedPrecisionPlugin()
                 self.amp_type = AMPType.NATIVE
                 return NativeMixedPrecisionPlugin()
@@ -227,7 +293,7 @@ def select_precision_plugin(self):
                     " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                 )
             else:
-                if self.distributed_backend == "ddp_sharded" or self.distributed_backend == "ddp_sharded_spawn":
+                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                     raise MisconfigurationException(
                         "Sharded Plugin is not supported with Apex AMP, "
                         "please use native AMP for 16-bit precision."
@@ -289,6 +355,12 @@ def select_training_type_plugin(self):
     def select_accelerator(self):
         if isinstance(self.distributed_backend, Accelerator):
             # custom accelerator from user
+            if self._precision_plugin is not None or self._training_type_plugin is not None:
+                # plugins also specified by user
+                rank_zero_warn(
+                    'Specified Precision and TrainingType Plugins will be ignored, '
+                    'since an Accelerator instance was provided'
+                )
             return self.distributed_backend
 
         if self.on_gpu:
@@ -299,8 +371,8 @@ def select_accelerator(self):
             acc_cls = CPUAccelerator
 
         return acc_cls(
-            precision_plugin=self.precision_plugin,
-            training_type_plugin=self.training_type_plugin,
+            precision_plugin=self.precision_plugin,
+            training_type_plugin=self.training_type_plugin,
         )
 
     def select_cluster_environment(self):
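
As a rough illustration of what `handle_given_plugins` does, the standalone sketch below reproduces only the sorting/validation step from the diff: at most one `TrainingTypePlugin` and one `PrecisionPlugin` are accepted, and anything else raises. The classes here are simplified stand-ins, not the real `pytorch_lightning` plugin or exception types, and the function returns a tuple instead of setting connector attributes.

from typing import Optional, Sequence


class TrainingTypePlugin:
    """Stand-in for pytorch_lightning.plugins.TrainingTypePlugin."""


class PrecisionPlugin:
    """Stand-in for pytorch_lightning.plugins.PrecisionPlugin."""


class MisconfigurationException(Exception):
    """Stand-in for the Lightning misconfiguration error."""


def handle_given_plugins(plugins: Optional[Sequence]):
    """Split user-supplied plugins into (training_type, precision), allowing at most one of each."""
    if plugins is None:
        return None, None
    if not isinstance(plugins, Sequence):
        # a single plugin may be passed bare instead of in a list
        plugins = [plugins]

    training_type = None
    precision = None
    for plug in plugins:
        if isinstance(plug, TrainingTypePlugin):
            if training_type is not None:
                raise MisconfigurationException('Found more than 1 training type plugin')
            training_type = plug
        elif isinstance(plug, PrecisionPlugin):
            if precision is not None:
                raise MisconfigurationException('Found more than 1 precision plugin')
            precision = plug
        else:
            raise MisconfigurationException(f'Found invalid type for plugin {plug}.')
    return training_type, precision


# One plugin of each kind is accepted; duplicates or unrelated objects raise.
print(handle_given_plugins([TrainingTypePlugin(), PrecisionPlugin()]))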