Skip to content

Commit 215a9c9

Browse files
BloodAxekaushikb11
authored andcommitted
Fix DPP + SyncBN (Lightning-AI#6838)
* Fix DPP + SyncBN Ensure that model is already on correct GPU before applying SyncBN conversion * Fix order of SyncBN for ddp_spawn
1 parent 31b2d2b commit 215a9c9

File tree

3 files changed

+6
-9
lines changed

3 files changed

+6
-9
lines changed

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,12 +257,12 @@ def pre_dispatch(self):
257257
self.dist.rank = self.global_rank
258258
self.dist.device = self.root_device
259259

260-
if self.sync_batchnorm:
261-
self.model = self.configure_sync_batchnorm(self.model)
262-
263260
# move the model to the correct device
264261
self.model_to_device()
265262

263+
if self.sync_batchnorm:
264+
self.model = self.configure_sync_batchnorm(self.model)
265+
266266
self.configure_ddp()
267267

268268
self.barrier()

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ def new_process(self, process_idx, trainer, mp_queue):
148148
self.dist.rank = self.global_rank
149149
self.dist.device = self.root_device
150150

151-
if self.sync_batchnorm:
152-
self.model = self.configure_sync_batchnorm(self.model)
153-
154151
# move the model to the correct device
155152
self.model_to_device()
156153

154+
if self.sync_batchnorm:
155+
self.model = self.configure_sync_batchnorm(self.model)
156+
157157
self.configure_ddp()
158158

159159
self.barrier()

tests/trainer/test_dataloaders.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,6 @@ def __len__(self):
721721
assert has_len(dataloader)
722722
assert has_iterable_dataset(dataloader)
723723
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
724-
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
725-
trainer.validate(model, val_dataloaders=[dataloader])
726724
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
727725
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
728726
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
@@ -735,7 +733,6 @@ def __len__(self):
735733
assert not has_len(dataloader)
736734
assert has_iterable_dataset(dataloader)
737735
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
738-
trainer.validate(model, val_dataloaders=dataloader)
739736
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
740737
trainer.test(model, test_dataloaders=dataloader)
741738
trainer.predict(model, dataloaders=dataloader)

0 commit comments

Comments
 (0)