
Commit 1dc5957

awaelchli authored and ninginthecloud committed
Restrict setup methods to accept a single model (Lightning-AI#10064)
1 parent 22d6923 commit 1dc5957

File tree: 5 files changed (+29, -53 lines)


CHANGELOG.md

Lines changed: 3 additions & 3 deletions
@@ -211,10 +211,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
     * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018), [#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
-    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
+    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
     * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
-    * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
-    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
+    * Implemented `DeepSpeedPlugin._setup_model_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
+    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_model_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
     * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 9 additions & 11 deletions
@@ -379,30 +379,28 @@ def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()

-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Setup multiple models and multiple optimizers together.
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+        """Setup a model and multiple optimizers together.

-        Currently only one model paired with a single optimizer is supported.
+        Currently only a single optimizer is supported.

         Return:
-            A list with one model wrapped into a :class:`deepspeed.DeepSpeedEngine` and list with a single
+            The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
             deepspeed optimizer.
         """
-        if not (len(models) == len(optimizers) == 1):
+        if len(optimizers) != 1:
             raise ValueError(
-                f"Currently only one model and one optimizer is supported with DeepSpeed."
-                f" Got {len(models)} models and {len(optimizers)} optimizers instead."
+                f"Currently only one optimizer is supported with DeepSpeed."
+                f" Got {len(optimizers)} optimizers instead."
             )

         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # normally we set this to the batch size, but it is not available here unless the user provides it
         # as part of the config
         self.config.setdefault("train_micro_batch_size_per_gpu", 1)
-        self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
+        self._model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
-        return [self._model], [optimizer]
+        return self._model, [optimizer]

     def _setup_model_and_optimizer(
         self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
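The net effect for DeepSpeed call sites: a bare model and a one-element optimizer list come back, instead of single-element lists on both sides. Below is a minimal sketch of that contract using plain PyTorch only; `fake_setup_model_and_optimizers` is a hypothetical stand-in for `DeepSpeedPlugin._setup_model_and_optimizers`, which would additionally wrap the model in a `deepspeed.DeepSpeedEngine`.

    # Illustrative only: a stand-in that mirrors the narrowed contract after this
    # commit (one model, exactly one optimizer, plain model returned instead of a
    # one-element list). Not the actual plugin code.
    from typing import List, Tuple

    import torch
    from torch import nn
    from torch.nn import Module
    from torch.optim import Optimizer


    def fake_setup_model_and_optimizers(model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
        if len(optimizers) != 1:
            raise ValueError(
                "Currently only one optimizer is supported with DeepSpeed."
                f" Got {len(optimizers)} optimizers instead."
            )
        return model, optimizers


    model = nn.Linear(8, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # old call sites unpacked lists on both sides: [model], [optimizer] = ...
    # new call sites pass and receive the model directly:
    model, [optimizer] = fake_setup_model_and_optimizers(model, [optimizer])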

pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 6 additions & 16 deletions
@@ -47,33 +47,23 @@ def configure_ddp(self) -> None:
         # For multi-node training, enabling bucketing will improve performance.
         self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
+        self._model, optimizers = self._setup_model_and_optimizers(
+            model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()

-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.

-        Currently only one model can be setup at once.
-
         Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
             and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
         """
-        if len(models) > 1:
-            raise ValueError(
-                "DDPSharded only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
         optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        return [model], optimizers
+        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
+        return model, optimizers

     def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
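The same single-model signature applies to `DDPShardedPlugin` here and to `DDPSpawnShardedPlugin` below. As a rough sketch of what the fairscale wrapping amounts to, assuming fairscale is installed and that a single-process "gloo" group is acceptable for a smoke test (the real plugin manages the process group itself and forwards `**self._ddp_kwargs`):

    # Illustrative fairscale wrapping only; not the plugin code.
    import os

    import torch
    import torch.distributed as dist
    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = torch.nn.Linear(4, 2)
    # OSS shards the optimizer state; extra kwargs are forwarded to the inner optimizer
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
    # ShardedDataParallel pairs the single model with the sharded optimizer(s)
    model = ShardedDataParallel(model, sharded_optimizer=optimizer)

    dist.destroy_process_group()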

pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 6 additions & 16 deletions
@@ -39,32 +39,22 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):

     def configure_ddp(self) -> None:
         trainer = self.lightning_module.trainer
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
+        self._model, optimizers = self._setup_model_and_optimizers(
+            model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers

-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.

-        Currently only one model can be setup at once.
-
         Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
             and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
         """
-        if len(models) > 1:
-            raise ValueError(
-                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
         optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        return [model], optimizers
+        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
+        return model, optimizers

     def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 5 additions & 7 deletions
@@ -61,18 +61,16 @@ def setup_environment(self) -> None:
     def setup(self) -> None:
         """Called by the accelerator to finish setup."""

-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Setup multiple models and multiple optimizers together.
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+        """Setup a model and multiple optimizers together.

         The returned objects are expected to be in the same order they were passed in. The default implementation will
-        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists.
+        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
         """
         # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
-        models = [self._setup_model(model) for model in models]
+        model = self._setup_model(model)
         optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
-        return models, optimizers
+        return model, optimizers

     def _setup_model(self, model: Module) -> Module:
         """Performs setup for the model, e.g., by wrapping it by another class."""
