@@ -92,7 +92,6 @@ def __init__(
         devices,
         tpu_cores,
         ipus,
-        distributed_backend,
         accelerator,
         strategy: Optional[Union[str, TrainingTypePlugin]],
         gpus,
@@ -113,7 +112,8 @@ def __init__(
         self._accelerator_type = None
 
         self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
-        self.distributed_backend = distributed_backend or accelerator
+        # TODO: Rename this to something else once all the distributed flags are moved to strategy
+        self.distributed_backend = accelerator
 
         self._init_deterministic(deterministic)
 
@@ -152,7 +152,7 @@ def __init__(
 
         self.plugins = plugins
 
-        self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)
+        self._handle_accelerator_and_strategy()
 
         self._validate_accelerator_and_devices()
 
@@ -176,10 +176,6 @@ def __init__(
         self._training_type_plugin_resolved = False
         self.accelerator = self.select_accelerator()
 
-        # override dist backend when using tpus
-        if self.use_tpu:
-            self.distributed_backend = "tpu"
-
         # init flags for SLURM+DDP to work
         self.world_size = 1
         self.interactive_ddp_procs = []
@@ -285,31 +281,16 @@ def _set_devices_if_none(self) -> None:
         elif self._accelerator_type == DeviceType.CPU:
             self.devices = self.num_processes
 
-    def _handle_accelerator_and_distributed_backend(
-        self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
-    ) -> None:
-        if distributed_backend is not None:
-            rank_zero_deprecation(
-                f"`Trainer(distributed_backend={distributed_backend!r})` "
-                "has been deprecated and will be removed in v1.5."
-                f" Use `Trainer(strategy={distributed_backend!r})` instead."
-            )
-            if self.strategy is not None:
-                raise MisconfigurationException(
-                    f"You have passed `Trainer(strategy={self.strategy!r})` but have"
-                    f" also passed `Trainer(distributed_backend={distributed_backend!r})`."
-                    f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead."
-                )
-
-        if accelerator is not None and accelerator in list(DistributedType):
+    def _handle_accelerator_and_strategy(self) -> None:
+        if self.distributed_backend is not None and self.distributed_backend in list(DistributedType):
             rank_zero_deprecation(
-                f"Passing `Trainer(accelerator={accelerator!r})` has been deprecated"
-                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator!r})` instead."
+                f"Passing `Trainer(accelerator={self.distributed_backend!r})` has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={self.distributed_backend!r})` instead."
             )
             if self.strategy is not None:
                 raise MisconfigurationException(
                     f"You have passed `Trainer(strategy={self.strategy!r})` but have"
-                    f" also passed `Trainer(accelerator={accelerator!r})`."
+                    f" also passed `Trainer(accelerator={self.distributed_backend!r})`."
                     f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead."
                 )
 
@@ -783,15 +764,15 @@ def select_cluster_environment(self) -> ClusterEnvironment:
             env = LightningEnvironment()
         return env
 
-    def set_distributed_mode(self, distributed_backend: Optional[str] = None):
+    def set_distributed_mode(self, strategy: Optional[str] = None):
 
-        if distributed_backend is None and self.is_training_type_in_plugins:
+        if strategy is None and self.is_training_type_in_plugins:
             return
 
-        if distributed_backend is not None and distributed_backend in TrainingTypePluginsRegistry:
-            self.distributed_backend = TrainingTypePluginsRegistry[distributed_backend]["distributed_backend"]
-        elif distributed_backend is not None:
-            self.distributed_backend = distributed_backend
+        if strategy is not None and strategy in TrainingTypePluginsRegistry:
+            self.distributed_backend = TrainingTypePluginsRegistry[strategy]["distributed_backend"]
+        elif strategy is not None:
+            self.distributed_backend = strategy
 
         if isinstance(self.distributed_backend, Accelerator):
             return
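
For context, a minimal usage sketch (not part of this commit) of the deprecation path that the new `_handle_accelerator_and_strategy` preserves, assuming PyTorch Lightning 1.5 with this change applied:

    import pytorch_lightning as pl

    # Deprecated since v1.5: a DistributedType value passed through `accelerator`
    # still works, but now emits the rank-zero deprecation warning shown above.
    trainer = pl.Trainer(accelerator="ddp", gpus=2)

    # Preferred: pass the distributed mode through `strategy` instead.
    trainer = pl.Trainer(strategy="ddp", gpus=2)

    # Passing both raises MisconfigurationException, per the guard above:
    # pl.Trainer(accelerator="ddp_spawn", strategy="ddp")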