Commit a451997

Avoid wrapping LightningModule in DDP plugins when not fitting (#9096)
* Avoid wrapping LightningModule in DDP plugins when not fitting
1 parent e2ecb8f commit a451997
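
The gist of the change: the DDP-family plugins now check which Trainer entry point is running before wrapping the LightningModule. A minimal restatement of the gate added below in `DDPPlugin.pre_dispatch` and `DDPSpawnPlugin.new_process` (the helper name here is made up for illustration and is not part of the Lightning API):

```python
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.trainer.states import TrainerFn


def maybe_configure_ddp(plugin: DDPPlugin) -> None:
    """Sketch of the gating logic this commit adds; illustration only."""
    # gradients are exchanged only while fitting, so the DistributedDataParallel
    # wrapper is only needed for trainer.fit()
    if plugin.lightning_module.trainer.state.fn == TrainerFn.FITTING:
        plugin.configure_ddp()
    # for trainer.validate()/test()/predict() the LightningModule is left unwrapped
    # and the plugin calls its validation_step/test_step/predict_step hooks directly
```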

File tree (9 files changed: +163 −23 lines)

- CHANGELOG.md
- pytorch_lightning/plugins/training_type/ddp.py
- pytorch_lightning/plugins/training_type/ddp_spawn.py
- pytorch_lightning/plugins/training_type/deepspeed.py
- pytorch_lightning/plugins/training_type/sharded.py
- pytorch_lightning/plugins/training_type/sharded_spawn.py
- tests/plugins/test_ddp_plugin.py
- tests/plugins/test_ddp_spawn_plugin.py
- tests/plugins/test_sharded_plugin.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -285,6 +285,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug causing logging with `log_gpu_memory='min_max'` not working ([#9013](https://github.com/PyTorchLightning/pytorch-lightning/pull/9013))
 
 
+- Fixed wrapping issue: avoid wrapping LightningModule with data-parallel modules when not fitting in `DDPPlugin`, `DDPSpawnPlugin`, `DDPShardedPlugin`, `DDPSpawnShardedPlugin` ([#9096](https://github.com/PyTorchLightning/pytorch-lightning/pull/9096))
+
+
 ## [1.4.3] - 2021-08-17
 
 - Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 18 additions & 9 deletions

@@ -56,6 +56,7 @@
 )
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_10:
     from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -361,7 +362,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()
 
-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self.pre_configure_ddp()
         self._model = DistributedDataParallel(
             LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -380,7 +381,10 @@ def pre_dispatch(self):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
-        self.configure_ddp()
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
 
         # share ddp pids to all processes
         self._share_information_to_prevent_deadlock()
@@ -424,17 +428,22 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
         tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs):
+    def training_step(self, *args, **kwargs) -> Optional[Any]:
         return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        if isinstance(self.model, DistributedDataParallel):
+            # used when calling `trainer.fit`
+            return self.model(*args, **kwargs)
+        else:
+            # used when calling `trainer.validate`
+            return self.lightning_module.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        return self.lightning_module.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def predict_step(self, *args, **kwargs) -> Any:
+        return self.lightning_module.predict_step(*args, **kwargs)
 
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
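
Because the wrapper is now conditional, the step methods can no longer assume `self.model` is a `DistributedDataParallel` instance. The dispatch in `validation_step` above boils down to the following (a standalone restatement for readability; `run_validation_step` is a hypothetical helper, not Lightning API):

```python
from typing import Any, Optional

from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.plugins import DDPPlugin


def run_validation_step(plugin: DDPPlugin, *args: Any, **kwargs: Any) -> Optional[Any]:
    """Mirrors the new DDPPlugin.validation_step; illustration only."""
    if isinstance(plugin.model, DistributedDataParallel):
        # trainer.fit(): the sanity/val loop runs inside fitting, so the call still
        # goes through the DDP wrapper (LightningDistributedModule routes to the hook)
        return plugin.model(*args, **kwargs)
    # trainer.validate(): the module was never wrapped, call the user hook directly
    return plugin.lightning_module.validation_step(*args, **kwargs)
```

`test_step` and `predict_step` do not need the `isinstance` check: `trainer.test` and `trainer.predict` never reach `configure_ddp`, so they always call the `LightningModule` hooks directly.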

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 18 additions & 9 deletions

@@ -46,6 +46,7 @@
     sync_ddp_if_available,
 )
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -201,7 +202,10 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
-        self.configure_ddp()
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
 
         self.barrier()
 
@@ -254,7 +258,7 @@ def _register_ddp_hooks(self) -> None:
             ddp_comm_wrapper=self._ddp_comm_wrapper,
         )
 
-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self.pre_configure_ddp()
         self._model = DistributedDataParallel(
             LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -340,17 +344,22 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
         tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs):
+    def training_step(self, *args, **kwargs) -> Optional[Any]:
         return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        if isinstance(self.model, DistributedDataParallel):
+            # used when calling `trainer.fit`
+            return self.model(*args, **kwargs)
+        else:
+            # used when calling `trainer.validate`
+            return self.lightning_module.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+        return self.lightning_module.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+    def predict_step(self, *args, **kwargs) -> Any:
+        return self.lightning_module.predict_step(*args, **kwargs)
 
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
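
The same gate lands in the spawn variant, inside each worker launched by `new_process`. What a user observes end to end (a sketch mirroring the new `test_ddp_spawn_configure_ddp` test further down; `BoringModel` is Lightning's test helper, any `LightningModule` works here):

```python
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # stand-in for any LightningModule

if __name__ == "__main__":
    model = BoringModel()
    trainer = Trainer(accelerator="ddp_spawn", num_processes=2, fast_dev_run=True)

    # fit: each spawned worker wraps the module in DistributedDataParallel
    trainer.fit(model)

    # validate/test/predict: nothing to all-reduce, so the module stays a plain
    # LightningModule and its *_step hooks are called directly
    trainer.validate(model, dataloaders=model.val_dataloader())
    trainer.test(model, dataloaders=model.test_dataloader())
    trainer.predict(model, dataloaders=model.predict_dataloader())
```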

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 9 additions & 0 deletions

@@ -818,3 +818,12 @@ def checkpoint_io(self) -> CheckpointIO:
     @checkpoint_io.setter
     def checkpoint_io(self, plugin: CheckpointIO) -> None:
         raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.")
+
+    def validation_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def test_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def predict_step(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
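
Why DeepSpeed needs these additions: `DeepSpeedPlugin` inherits from `DDPPlugin`, whose evaluation steps now call the bare `LightningModule` hooks when the model is not DDP-wrapped. DeepSpeed, however, always runs the module inside its engine, so the plugin re-overrides the three methods to keep forwarding through `self.model`. A commented restatement of the overrides above (the reasoning in the comments is my reading of the diff, not text from the commit):

```python
from pytorch_lightning.plugins import DDPPlugin


class DeepSpeedPluginExcerpt(DDPPlugin):
    """Illustrative excerpt only; the real DeepSpeedPlugin defines far more."""

    def validation_step(self, *args, **kwargs):
        # `self.model` is the DeepSpeed-wrapped module here and should stay in the
        # call path during evaluation as well, so forward through it instead of
        # calling `self.lightning_module.validation_step` directly
        return self.model(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def predict_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)
```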

pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ class DDPShardedPlugin(DDPPlugin):
 
     _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
 
-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model),

pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@
 class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
-    def configure_ddp(self):
+    def configure_ddp(self) -> None:
         self._wrap_optimizers()
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers

tests/plugins/test_ddp_plugin.py

Lines changed: 36 additions & 1 deletion

@@ -18,9 +18,10 @@
 import torch
 from torch.nn.parallel import DistributedDataParallel
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.plugins import DDPPlugin
 from pytorch_lightning.plugins.environments import LightningEnvironment
+from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -94,3 +95,37 @@ def creates_children(self):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(skip_windows=True)
+def test_ddp_configure_ddp():
+    """Tests with ddp plugin."""
+    model = BoringModel()
+    ddp_plugin = DDPPlugin()
+    trainer = Trainer(
+        max_epochs=1,
+        plugins=[ddp_plugin],
+    )
+    # test wrap the model if fitting
+    trainer.state.fn = TrainerFn.FITTING
+    trainer.accelerator.connect(model)
+    trainer.accelerator.setup_environment()
+    trainer.accelerator.setup(trainer)
+    trainer.lightning_module.trainer = trainer
+    assert isinstance(trainer.model, LightningModule)
+    trainer._pre_dispatch()
+    # in DDPPlugin configure_ddp(), model wrapped by DistributedDataParallel
+    assert isinstance(trainer.model, DistributedDataParallel)
+
+    trainer = Trainer(
+        max_epochs=1,
+        plugins=[ddp_plugin],
+    )
+    # test do not wrap the model if trainerFN is not fitting
+    trainer.accelerator.connect(model)
+    trainer.accelerator.setup_environment()
+    trainer.accelerator.setup(trainer)
+    trainer.lightning_module.trainer = trainer
+    trainer._pre_dispatch()
+    # in DDPPlugin configure_ddp(), model are still LightningModule
+    assert isinstance(trainer.model, LightningModule)
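
Note that this test drives the accelerator lifecycle by hand (`connect`, `setup_environment`, `setup`, then `_pre_dispatch`) rather than calling `trainer.fit`, which lets it assert on the type of `trainer.model` in-process without launching extra ranks. To run just this test from a dev checkout with the test requirements installed (it is skipped on Windows by the `RunIf` marker):

```python
import pytest

# equivalent to: pytest -v tests/plugins/test_ddp_plugin.py::test_ddp_configure_ddp
pytest.main(["-v", "tests/plugins/test_ddp_plugin.py::test_ddp_configure_ddp"])
```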

tests/plugins/test_ddp_spawn_plugin.py

Lines changed: 37 additions & 1 deletion

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from torch.nn.parallel.distributed import DistributedDataParallel
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.plugins import DDPSpawnPlugin
+from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers.boring_model import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
 
@@ -77,3 +79,37 @@ def test_ddp_spawn_extra_parameters(tmpdir):
     trainer.fit(model, datamodule=dm)
     assert trainer.callback_metrics[val_name] == torch.tensor(val)
     assert model.test_val == "test_val"
+
+
+class BoringModelDDP(BoringModel):
+    def on_train_start(self) -> None:
+        """Check if trainer module is wrapped as DistributedDataParallel during training stage."""
+        assert isinstance(self.trainer.model, DistributedDataParallel)
+
+    def on_validation_start(self) -> None:
+        """Check if trainer module remains as LightningModule during test stage."""
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            assert isinstance(self.trainer.model, DistributedDataParallel)
+        else:
+            assert isinstance(self.trainer.model, LightningModule)
+
+    def on_test_start(self) -> None:
+        """Check if trainer module remains as LightningModule during test stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+    def on_predict_start(self) -> None:
+        """Check if trainer module remains as LightningModule during prediction stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+
+@RunIf(skip_windows=True)
+def test_ddp_spawn_configure_ddp(tmpdir):
+    """Tests with ddp spawn plugin."""
+    trainer = Trainer(default_root_dir=tmpdir, num_processes=2, accelerator="ddp_spawn", fast_dev_run=True)
+
+    model = BoringModelDDP()
+
+    trainer.fit(model)
+    trainer.validate(model, dataloaders=model.val_dataloader())
+    trainer.test(model, dataloaders=model.test_dataloader())
+    trainer.predict(model, dataloaders=model.predict_dataloader())

tests/plugins/test_sharded_plugin.py

Lines changed: 40 additions & 1 deletion

@@ -4,13 +4,18 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
+from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
+if _FAIRSCALE_AVAILABLE:
+    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
+
 
 @pytest.mark.parametrize("clip_val", [0, 10])
 @RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
@@ -249,3 +254,37 @@ def test_ddp_sharded_plugin_manual_optimization(tmpdir):
     model = ManualBoringModel()
     trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=2, gpus=2)
     trainer.fit(model)
+
+
+class BoringModelSharded(BoringModel):
+    def on_train_start(self) -> None:
+        """Check if trainer module is wrapped as ShardedDataParallel during training stage."""
+        assert isinstance(self.trainer.model, ShardedDataParallel)
+
+    def on_test_start(self) -> None:
+        """Check if trainer module remains as LightningModule during test stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+    def on_validation_start(self) -> None:
+        """Check if trainer module remains as LightningModule during test stage."""
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            assert isinstance(self.trainer.model, ShardedDataParallel)
+        else:
+            assert isinstance(self.trainer.model, LightningModule)
+
+    def on_predict_start(self) -> None:
+        """Check if trainer module remains as LightningModule during prediction stage."""
+        assert isinstance(self.trainer.model, LightningModule)
+
+
+@RunIf(skip_windows=True, fairscale=True)
+def test_configure_ddp(tmpdir):
+    """Tests with ddp sharded plugin."""
+    trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=True)
+
+    model = BoringModelSharded()
+
+    trainer.fit(model)
+    trainer.test(model, dataloaders=model.test_dataloader())
+    trainer.validate(model, dataloaders=model.val_dataloader())
+    trainer.predict(model, dataloaders=model.predict_dataloader())
