
Commit 3ea5347

Authored by awaelchli, Sean Naren, and pre-commit-ci[bot]

Update setup logic in training type plugins (deepspeed) [2 / n] (#10009)

Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e8beceb commit 3ea5347

File tree

2 files changed: +52 −12 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -207,6 +207,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
 * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
+* Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))

 ### Changed

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 51 additions & 12 deletions
@@ -22,7 +22,9 @@
 from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union

 import torch
+from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler

 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -377,6 +379,50 @@ def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()

+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Setup multiple models and multiple optimizers together.
+
+        Currently only one model paired with a single optimizer is supported.
+
+        Return:
+            A list with one model wrapped into a :class:`deepspeed.DeepSpeedEngine` and list with a single
+            deepspeed optimizer.
+        """
+        if not (len(models) == len(optimizers) == 1):
+            raise ValueError(
+                f"Currently only one model and one optimizer is supported with DeepSpeed."
+                f" Got {len(models)} models and {len(optimizers)} optimizers instead."
+            )
+
+        # train_micro_batch_size_per_gpu is used for throughput logging purposes
+        # normally we set this to the batch size, but it is not available here unless the user provides it
+        # as part of the config
+        self.config.setdefault("train_micro_batch_size_per_gpu", 1)
+        self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
+        self._set_deepspeed_activation_checkpointing()
+        return [self._model], [optimizer]
+
+    def _setup_model_and_optimizer(
+        self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
+    ):
+        """Initialize one model and one optimizer with an optional learning rate scheduler.
+
+        This calls :func:`deepspeed.initialize` internally.
+        """
+        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+            args=argparse.Namespace(device_rank=self.root_device.index),
+            config=self.config,
+            model=model,
+            model_parameters=model_parameters,  # type: ignore
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            dist_init_required=False,
+        )
+        return deepspeed_engine, deepspeed_optimizer
+
     def init_deepspeed(self):
         # check that `configure_gradient_clipping` hook isn't overriden since deepspeed handles
         # gradient clipping internally
@@ -441,18 +487,7 @@ def _initialize_deepspeed_train(self, model):
         optimizer, lr_scheduler, _ = self._init_optimizers()

         scheduler = lr_scheduler["scheduler"]
-
-        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
-        model, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
-            args=argparse.Namespace(device_rank=self.root_device.index),
-            config=self.config,
-            model=model,
-            model_parameters=model_parameters,
-            optimizer=optimizer,
-            lr_scheduler=scheduler,
-            dist_init_required=False,
-        )
-
+        model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
         self._set_deepspeed_activation_checkpointing()

         # although we set these here, deepspeed manages the specific optimizer logic
@@ -568,6 +603,10 @@ def _format_config(self):
         self._format_precision_config()

     def _format_batch_size_and_grad_accum_config(self):
+        # todo: using lite, we do not support these variables within the config
+        if self.lightning_module is None:
+            return
+
         if "gradient_accumulation_steps" in self.config:
             raise MisconfigurationException(
                 "Do not set `gradient_accumulation_steps` in the DeepSpeed config"
