Commit 0ab298f

Update fully_sharded.py
1 parent c2a7fd5 commit 0ab298f

File tree

pytorch_lightning/strategies/fully_sharded.py (1 file changed: +2, −6)

pytorch_lightning/strategies/fully_sharded.py

Lines changed: 2 additions & 6 deletions
@@ -135,16 +135,15 @@ def setup_distributed(self) -> None:
 
     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
-        self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        optimizers_to_device(self.optimizers, self.root_device)
 
         if self._layer_sync:
             self.model = self._layer_sync.apply(self.model)
 
         self.configure_ddp()
         self.barrier()
         self.setup_optimizers(trainer)
+        optimizers_to_device(self.optimizers, self.root_device)
+        self.setup_precision_plugin()
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
@@ -181,9 +180,6 @@ def configure_ddp(self) -> None:
         # (TODO: need to figure out solution)
         self.model_to_device()
 
-        # setup optimizers after fully sharded has wrapped the lightning module
-        self.setup_optimizers(self.lightning_module.trainer)
-
     def model_to_device(self) -> None:
         log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
         # ensure we update the device type in the lightning module
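In effect, the commit consolidates optimizer setup into a single place after configure_ddp() has wrapped the LightningModule (the deleted comment in configure_ddp notes that optimizers must be set up "after fully sharded has wrapped the lightning module"), and it moves optimizers_to_device() and setup_precision_plugin() to run after that point as well. A minimal sketch of the resulting setup() flow, reconstructed from the hunks above; all names come from the diff, the rest of the class is omitted:

    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if self._layer_sync:
            # apply layer synchronization before the module is wrapped/sharded
            self.model = self._layer_sync.apply(self.model)

        self.configure_ddp()  # wraps the LightningModule in the fully sharded wrapper
        self.barrier()
        # optimizers are now created against the wrapped module's parameters
        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)
        self.setup_precision_plugin()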
