 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.seed import reset_seed
 
 if _BAGUA_AVAILABLE:
@@ -148,6 +150,32 @@ def _set_node_environment_variables(self) -> None:
         os.environ["WORLD_SIZE"] = str(self.world_size)
         os.environ["LOCAL_RANK"] = str(self.local_rank)
 
+    def setup(self, trainer: "pl.Trainer") -> None:
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
+
+        self.accelerator.setup(trainer)
+
+        # move the model to the correct device
+        self.model_to_device()
+
+        if self._layer_sync:
+            self.model = self._layer_sync.apply(self.model)
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = trainer.state.fn
+
+        # set up optimizers after the module has been moved to the device
+        # but before the module has been wrapped
+        self.setup_optimizers(trainer)
+        optimizers_to_device(self.optimizers, self.root_device)
+
+        if trainer_fn == TrainerFn.FITTING:
+            self._configure_bagua_model(trainer)
+
+        self.setup_precision_plugin()
+
     def _check_qadam_optimizer(self) -> None:
         has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers])
 
@@ -156,12 +187,12 @@ def _check_qadam_optimizer(self) -> None:
 
         self._bagua_kwargs["q_adam_optimizer"] = self.optimizers[0]
 
-    def configure_ddp(self) -> None:
+    def _configure_bagua_model(self, trainer: "pl.Trainer") -> None:
         model = LightningBaguaModule(self.model)  # type: ignore[arg-type]
         self._model = self._setup_model(model)
 
         # start the background communication for async algorithm
-        if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
+        if trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.resume(self.model)  # type: ignore
 
     def _setup_model(self, model: Module) -> BaguaDistributedDataParallel:
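For reference, a minimal usage sketch (not part of this diff) of how the strategy touched above is typically selected. It assumes the public BaguaStrategy(algorithm=...) constructor and a LightningModule named MyLightningModule defined elsewhere:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import BaguaStrategy

# "async" is the algorithm that exercises the bagua_algorithm.resume() branch
# in _configure_bagua_model above; "gradient_allreduce" is the default.
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=BaguaStrategy(algorithm="async"),
)
# trainer.fit(MyLightningModule())  # MyLightningModule is a placeholder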