Commit a456b1b

update bagua
1 parent 0ab298f commit a456b1b

2 files changed (+35, -4 lines)


pytorch_lightning/strategies/bagua.py

Lines changed: 33 additions & 2 deletions
@@ -12,9 +12,11 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.seed import reset_seed
 
 if _BAGUA_AVAILABLE:
@@ -148,6 +150,35 @@ def _set_node_environment_variables(self) -> None:
         os.environ["WORLD_SIZE"] = str(self.world_size)
         os.environ["LOCAL_RANK"] = str(self.local_rank)
 
+    def setup(self, trainer: "pl.Trainer") -> None:
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
+
+        self.accelerator.setup(trainer)
+
+        # move the model to the correct device
+        self.model_to_device()
+
+        if self._layer_sync:
+            self.model = self._layer_sync.apply(self.model)
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = trainer.state.fn
+
+        # set up optimizers after the module has been moved to the device
+        # but before the module has been wrapped
+        self.setup_optimizers(trainer)
+        optimizers_to_device(self.optimizers, self.root_device)
+
+        if trainer_fn == TrainerFn.FITTING:
+            self._configure_bagua_model(trainer)
+
+        self.setup_precision_plugin()
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
+
     def _check_qadam_optimizer(self) -> None:
         has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers])
 
@@ -156,12 +187,12 @@ def _check_qadam_optimizer(self) -> None:
 
         self._bagua_kwargs["q_adam_optimizer"] = self.optimizers[0]
 
-    def configure_ddp(self) -> None:
+    def _configure_bagua_model(self, trainer: "pl.Trainer") -> None:
         model = LightningBaguaModule(self.model) # type: ignore[arg-type]
         self._model = self._setup_model(model)
 
         # start the background communication for async algorithm
-        if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
+        if trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.resume(self.model) # type: ignore
 
     def _setup_model(self, model: Module) -> BaguaDistributedDataParallel:
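
The reordered setup() above runs whenever a Trainer is launched with the Bagua strategy. A minimal sketch of such a run follows; it is not part of this commit, and the toy model, dataset, and Trainer arguments are illustrative assumptions.

# Sketch: exercising BaguaStrategy.setup() through a normal fit() call.
# Assumes 2 GPUs and a working bagua installation; model and data are toy placeholders.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.strategies import BaguaStrategy


class RandomDataset(Dataset):
    def __init__(self, size: int = 32, length: int = 256):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer(
        max_epochs=1,
        accelerator="gpu",
        devices=2,
        strategy=BaguaStrategy(algorithm="gradient_allreduce"),
    )
    # During fit(), setup() moves the model to its device, creates the optimizers,
    # and wraps the module via _configure_bagua_model() because trainer.state.fn
    # is TrainerFn.FITTING; trainer.validate()/test() would skip the wrapping.
    trainer.fit(ToyModel(), DataLoader(RandomDataset(), batch_size=8))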

tests/strategies/test_bagua_strategy.py

Lines changed: 2 additions & 2 deletions
@@ -85,9 +85,9 @@ def test_configuration(algorithm, tmpdir):
     ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
         if algorithm == "qadam":
             with pytest.raises(MisconfigurationException, match="Bagua QAdam can only accept one QAdamOptimizer"):
-                trainer.strategy.configure_ddp()
+                trainer.strategy._configure_bagua_model(trainer)
         else:
-            trainer.strategy.configure_ddp()
+            trainer.strategy._configure_bagua_model(trainer)
 
 
 @RunIf(bagua=True, min_gpus=1)
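
For reference, the "qadam" branch of this test mirrors the case where the LightningModule returns Bagua's QAdamOptimizer from configure_optimizers and the strategy is created with algorithm="qadam"; per the error message above, exactly one QAdamOptimizer is accepted. A hedged sketch of that configuration follows; the toy module and hyperparameter values are assumptions, not taken from this commit.

# Sketch: the user-facing configuration the "qadam" test case corresponds to.
import torch
import pytorch_lightning as pl
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer
from pytorch_lightning.strategies import BaguaStrategy


class QAdamToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # A single QAdamOptimizer, as _check_qadam_optimizer() expects;
        # the strategy forwards it to Bagua as the "q_adam_optimizer" kwarg.
        return QAdamOptimizer(self.parameters(), lr=1e-3)


trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=BaguaStrategy(algorithm="qadam"),
)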
