
Commit 25de488

donlapark, otaj, carmocca, and rohitgr7 authored
Fixes various typing errors in pytorch_lightning/strategies/deepspeed.py (#13832)
Co-authored-by: otaj <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent b37e466 commit 25de488

File tree: 3 files changed (+71, -39 lines)

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -63,7 +63,6 @@ module = [
     "pytorch_lightning.profilers.simple",
     "pytorch_lightning.strategies.ddp",
     "pytorch_lightning.strategies.ddp_spawn",
-    "pytorch_lightning.strategies.deepspeed",
     "pytorch_lightning.strategies.fully_sharded",
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.sharded",

src/pytorch_lightning/plugins/precision/deepspeed.py

Lines changed: 2 additions & 1 deletion
@@ -26,9 +26,10 @@
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
 
+_DEEPSPEED_AVAILABLE = _RequirementAvailable("deepspeed")
 _DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0")
 if TYPE_CHECKING:
-    if pl.strategies.deepspeed._DEEPSPEED_AVAILABLE:
+    if _DEEPSPEED_AVAILABLE:
         import deepspeed
 
 warning_cache = WarningCache()
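
For context, the pattern adjusted here is an availability-guarded, type-checking-only import: the module now defines its own _DEEPSPEED_AVAILABLE flag and imports deepspeed only for static analysis. A minimal, self-contained sketch of the same idea, using importlib.util.find_spec in place of Lightning's _RequirementAvailable helper (the flag name below is illustrative):

import importlib.util
from typing import TYPE_CHECKING

# Illustrative availability probe; the Lightning code uses
# _RequirementAvailable("deepspeed") for the same purpose.
_DEEPSPEED_INSTALLED = importlib.util.find_spec("deepspeed") is not None

if TYPE_CHECKING:
    if _DEEPSPEED_INSTALLED:
        # Seen only by static type checkers, never executed at runtime, so
        # annotations such as "deepspeed.DeepSpeedEngine" resolve without
        # importing deepspeed for users who do not have it installed.
        import deepspeed  # noqa: F401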

src/pytorch_lightning/strategies/deepspeed.py

Lines changed: 69 additions & 37 deletions
@@ -19,7 +19,7 @@
 import platform
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import Any, cast, Dict, Generator, List, Mapping, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -48,12 +48,12 @@
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _LRScheduler, _PATH, LRSchedulerConfig, ReduceLROnPlateau, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
 
 warning_cache = WarningCache()
 
-_DEEPSPEED_AVAILABLE: bool = _RequirementAvailable("deepspeed")
+_DEEPSPEED_AVAILABLE = _RequirementAvailable("deepspeed")
 if _DEEPSPEED_AVAILABLE:
     import deepspeed

@@ -76,7 +76,7 @@ def __init__(
         super().__init__(pl_module)
         self.precision = precision
 
-    def forward(self, *inputs, **kwargs):
+    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         inputs = apply_to_collection(inputs, Tensor, function=self._batch_to)
         return super().forward(*inputs, **kwargs)

@@ -123,7 +123,7 @@ def __init__(
         reduce_bucket_size: int = 200_000_000,
         zero_allow_untested_optimizer: bool = True,
         logging_batch_size_per_gpu: Union[str, int] = "auto",
-        config: Optional[Union[Path, str, dict]] = None,
+        config: Optional[Union[_PATH, Dict[str, Any]]] = None,
         logging_level: int = logging.WARN,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
@@ -142,7 +142,7 @@ def __init__(
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
-        lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#deepspeed`.
+        lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed`.
 
         .. warning:: ``DeepSpeedStrategy`` is in beta and subject to change.
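
The narrowed ``config`` annotation above reflects the two supported inputs: an in-memory dict or a path to a JSON file. A hedged usage sketch (the Trainer arguments and the ds_config.json file name are illustrative, not taken from this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy

# Option 1: pass the DeepSpeed configuration as a plain dict.
strategy = DeepSpeedStrategy(config={"zero_optimization": {"stage": 3}})

# Option 2: pass a path to a JSON file holding the same configuration.
# strategy = DeepSpeedStrategy(config="ds_config.json")

trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy)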
@@ -331,7 +331,7 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale
 
-    def _load_config(self, config):
+    def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
         if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
             config = os.environ[self.DEEPSPEED_ENV_VAR]
@@ -342,9 +342,10 @@ def _load_config(self, config):
                 )
             with open(config) as f:
                 config = json.load(f)
+        assert isinstance(config, dict) or config is None
         return config
 
-    def setup_distributed(self):
+    def setup_distributed(self) -> None:
         reset_seed()
 
         # determine which process we are and world size
@@ -357,8 +358,10 @@ def setup_distributed(self):
         self._config_initialized = True
 
     def setup(self, trainer: "pl.Trainer") -> None:
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)
         # we set the device so that optimizers can be created with distributed comms.
+        assert self.lightning_module is not None
         self.lightning_module._device = self.root_device
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
@@ -367,6 +370,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.barrier()
 
     def _init_deepspeed_distributed(self) -> None:
+        assert self.cluster_environment is not None
         if platform.system() != "Windows":
             # do not set env variables on windows, allow deepspeed to control setup
             self._set_node_environment_variables()
@@ -378,14 +382,15 @@ def _init_deepspeed_distributed(self) -> None:
         self._process_group_backend = self._get_process_group_backend()
         deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
 
-    def _get_process_group_backend(self):
+    def _get_process_group_backend(self) -> str:
         return (
             self._process_group_backend
             or _get_process_group_backend_from_env()
             or get_default_process_group_backend_for_device(self.root_device)
         )
 
     def _set_node_environment_variables(self) -> None:
+        assert self.cluster_environment is not None
         os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
         os.environ["RANK"] = str(self.global_rank)
@@ -396,7 +401,9 @@ def _set_node_environment_variables(self) -> None:
     def restore_checkpoint_after_setup(self) -> bool:
         return True
 
-    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+    def _setup_model_and_optimizers(
+        self, model: Module, optimizers: List[Optimizer]
+    ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]:
         """Setup a model and multiple optimizers together.
 
         Currently only a single optimizer is supported.
@@ -414,14 +421,18 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # normally we set this to the batch size, but it is not available here unless the user provides it
         # as part of the config
+        assert self.config is not None
         self.config.setdefault("train_micro_batch_size_per_gpu", 1)
         self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
         return self.model, [optimizer]
 
     def _setup_model_and_optimizer(
-        self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
-    ):
+        self,
+        model: Module,
+        optimizer: Optional[Optimizer],
+        lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+    ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
         This calls :func:`deepspeed.initialize` internally.
@@ -431,14 +442,15 @@ def _setup_model_and_optimizer(
             args=argparse.Namespace(device_rank=self.root_device.index),
             config=self.config,
             model=model,
-            model_parameters=model_parameters,  # type: ignore
+            model_parameters=model_parameters,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
             dist_init_required=False,
         )
         return deepspeed_engine, deepspeed_optimizer
 
-    def init_deepspeed(self):
+    def init_deepspeed(self) -> None:
+        assert self.lightning_module is not None
         # deepspeed handles gradient clipping internally
         if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
             rank_zero_warn(
@@ -464,6 +476,7 @@ def init_deepspeed(self):
                 "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
             )
 
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision)
 
         if self.lightning_module.trainer and self.lightning_module.trainer.training:
@@ -472,6 +485,7 @@ def init_deepspeed(self):
             self._initialize_deepspeed_inference(model)
 
     def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Optional[int]]:
+        assert self.lightning_module is not None
         optimizers, lr_schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
         if len(optimizers) > 1 or len(lr_schedulers) > 1:
             raise MisconfigurationException(
@@ -485,10 +499,13 @@ def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Opti
 
     @property
     def zero_stage_3(self) -> bool:
-        return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3
+        assert isinstance(self.config, dict)
+        zero_optimization = self.config.get("zero_optimization")
+        return zero_optimization is not None and zero_optimization.get("stage") == 3
 
-    def _initialize_deepspeed_train(self, model):
+    def _initialize_deepspeed_train(self, model: Module) -> None:
         optimizer, scheduler = None, None
+        assert isinstance(self.config, dict)
         if "optimizer" in self.config:
             rank_zero_info(
                 "You have specified an optimizer and/or scheduler within the DeepSpeed config."
@@ -538,7 +555,8 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         with model_parallel_context:
             yield
 
-    def _set_deepspeed_activation_checkpointing(self):
+    def _set_deepspeed_activation_checkpointing(self) -> None:
+        assert isinstance(self.config, dict)
         if self.config.get("activation_checkpointing"):
             checkpoint_config = self.config["activation_checkpointing"]
             deepspeed.checkpointing.configure(
@@ -549,8 +567,9 @@ def _set_deepspeed_activation_checkpointing(self):
                 profile=checkpoint_config.get("profile"),
             )
 
-    def _initialize_deepspeed_inference(self, model):
+    def _initialize_deepspeed_inference(self, model: Module) -> None:
         # todo: Currently DeepSpeed requires optimizers at inference to partition weights correctly
+        assert isinstance(self.config, dict)
         optimizer, scheduler = None, None
         if "optimizer" not in self.config:
             rank_zero_info(
@@ -585,13 +604,15 @@ def _initialize_deepspeed_inference(self, model):
         self.model = model
 
     @property
-    def lightning_module(self):
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
         # the model may not be wrapped with DeepEngine & LightningDeepSpeedModule if calling this too early
         module = getattr(self.model, "module", self.model)
-        return module.module if isinstance(module, LightningDeepSpeedModule) else module
+        module = module.module if isinstance(module, LightningDeepSpeedModule) else module
+        assert isinstance(module, pl.LightningModule) or module is None
+        return module
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, int]:
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
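
The lightning_module property now advertises an Optional return type, and the added assert is the standard way to make that narrowing visible to mypy. A generic sketch of the idiom (the Engine/Model classes and unwrap function are illustrative stand-ins, not Lightning APIs):

from typing import Optional, Union

class Engine: ...
class Model: ...

def unwrap(wrapped: Union[Engine, Model, None]) -> Optional[Model]:
    # The assert both documents the expectation at runtime and lets mypy
    # narrow Union[Engine, Model, None] down to Optional[Model], so the
    # declared return type checks.
    assert isinstance(wrapped, Model) or wrapped is None
    return wrapped

print(unwrap(Model()))
print(unwrap(None))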

@@ -616,17 +637,18 @@ def handles_gradient_accumulation(self) -> bool:
         """Whether the plugin handles gradient accumulation internally."""
         return True
 
-    def _format_config(self):
+    def _format_config(self) -> None:
         if self.config is None:
             raise MisconfigurationException(
                 "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
-                " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#deepspeed"
+                " See: https://pytorch-lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed"
             )
         self._format_batch_size_and_grad_accum_config()
         self._format_precision_config()
 
-    def _format_batch_size_and_grad_accum_config(self):
+    def _format_batch_size_and_grad_accum_config(self) -> None:
         # todo: using lite, we do not support these variables within the config
+        assert isinstance(self.config, dict)
         if self.lightning_module is None:
             return

@@ -642,16 +664,17 @@ def _format_batch_size_and_grad_accum_config(self):
         if "gradient_clipping" not in self.config:
             self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
 
-    def _auto_select_batch_size(self):
+    def _auto_select_batch_size(self) -> int:
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # by default we try to use the batch size of the loader
+        assert self.lightning_module is not None
         batch_size = 1
         train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source
         if train_dl_source.is_defined():
             try:
                 train_dataloader = train_dl_source.dataloader()
                 if hasattr(train_dataloader, "batch_sampler"):
-                    batch_size = train_dataloader.batch_sampler.batch_size
+                    batch_size = train_dataloader.batch_sampler.batch_size  # type: ignore[union-attr]
             # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup`
             # to have been called before
             except Exception:
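
The batch-sampler line keeps its suppression but switches to an error-code-scoped ignore, so mypy only silences the known Optional-attribute access instead of every diagnostic on that line. A small self-contained illustration of why the union-attr error appears (the Loader/BatchSampler classes below are illustrative, not the torch ones):

from typing import Optional

class BatchSampler:
    batch_size: int = 32

class Loader:
    # batch_sampler may be None, which is what triggers mypy's
    # union-attr error on direct attribute access.
    batch_sampler: Optional[BatchSampler] = None

def read_batch_size(loader: Loader) -> int:
    # A bare "# type: ignore" would hide every error on this line;
    # the scoped form only silences the Optional-access complaint.
    return loader.batch_sampler.batch_size  # type: ignore[union-attr]

loader = Loader()
loader.batch_sampler = BatchSampler()
print(read_batch_size(loader))  # 32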
@@ -664,6 +687,7 @@ def _auto_select_batch_size(self):
         return batch_size
 
     def _format_precision_config(self) -> None:
+        assert isinstance(self.config, dict)
         if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
             if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
@@ -707,7 +731,7 @@ def _create_default_config(
         single_submit: bool,
         overlap_events: bool,
         thread_count: int,
-        **zero_kwargs,
+        **zero_kwargs: Any,
     ) -> Dict:
         cfg = {
             "activation_checkpointing": {
@@ -753,7 +777,7 @@ def _create_default_config(
         return cfg
 
     @property
-    def deepspeed_engine(self):
+    def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine":
         return self.model
 
     @property
@@ -786,7 +810,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Op
                 "When saving the DeepSpeed Stage 3 checkpoint, "
                 "each worker will save a shard of the checkpoint within a directory. "
                 "If a single file is required after training, "
-                "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
+                "see https://pytorch-lightning.readthedocs.io/en/stable/advanced/model_parallel.html#"
                 "deepspeed-zero-stage-3-single-file for instructions."
             )
         # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
@@ -799,10 +823,12 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         if self.load_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
-            checkpoint_path = self.broadcast(checkpoint_path)
+            checkpoint_path = cast(_PATH, self.broadcast(checkpoint_path))
             return super().load_checkpoint(checkpoint_path)
 
         # Rely on deepspeed to load the checkpoint and necessary information
+        assert self.lightning_module is not None
+
         from pytorch_lightning.trainer.states import TrainerFn
 
         is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
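
broadcast() is not statically guaranteed to hand back a path-like value, so cast() is used to tell mypy the checkpoint path is still a _PATH; it has no effect at runtime. A generic sketch of the idiom (the broadcast stub and file name below are illustrative):

from pathlib import Path
from typing import Any, Union, cast

_PATH = Union[str, Path]

def broadcast(obj: Any) -> Any:
    # Toy stand-in for a collective broadcast: returns the object unchanged.
    return obj

def resolve(path: _PATH) -> _PATH:
    # cast() is a no-op at runtime; it only reassures the type checker
    # that the broadcast result is still a path-like value.
    return cast(_PATH, broadcast(path))

print(resolve("checkpoints/last.ckpt"))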
@@ -818,6 +844,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
 
     @property
     def lightning_restore_optimizer(self) -> bool:
+        assert self.lightning_module is not None
         # managed by DeepSpeed
         if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
             rank_zero_warn(
@@ -842,11 +869,13 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
             ckpt: The ckpt file.
         """
 
-        def load(module: torch.nn.Module, prefix=""):
+        assert self.lightning_module is not None
+
+        def load(module: torch.nn.Module, prefix: str = "") -> None:
 
-            missing_keys = []
-            unexpected_keys = []
-            error_msgs = []
+            missing_keys: List[str] = []
+            unexpected_keys: List[str] = []
+            error_msgs: List[str] = []
             state_dict = ckpt["state_dict"]
 
             # copy state_dict so _load_from_state_dict can modify it
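
Annotating the empty lists is needed because mypy cannot infer an element type from a bare []; with the hints, later appends and the arguments passed on to _load_from_state_dict get a concrete List[str] type. A minimal illustration (the key name is made up):

from typing import List

# Without the annotation mypy has no element type to check appends against
# (and under strict settings it reports "need type annotation" outright).
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []

missing_keys.append("encoder.weight")  # OK: str matches List[str]
# missing_keys.append(3)               # would be rejected by mypy
print(missing_keys)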
@@ -914,14 +943,17 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             offload_optimizer_device="nvme",
         )
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.val_step_context():
             return self.model(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.test_step_context():
             return self.model(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.predict_step_context():
             return self.model(*args, **kwargs)
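
Most of the additions in this file follow one recipe: attributes such as self.model, self.lightning_module, self.accelerator, and self.config are Optional on the base strategy, so an assert immediately before use documents the runtime precondition and lets mypy narrow away None. A condensed sketch of the pattern (ToyStrategy is illustrative, not a Lightning class):

from typing import Callable, Optional

class ToyStrategy:
    def __init__(self) -> None:
        # Populated later during setup(); None beforehand, hence Optional.
        self.model: Optional[Callable[[int], int]] = None

    def predict_step(self, x: int) -> int:
        # Mirrors the asserts added to validation_step/test_step/predict_step:
        # narrows Optional[...] to the callable before invoking it.
        assert self.model is not None
        return self.model(x)

strategy = ToyStrategy()
strategy.model = lambda x: x * 2
print(strategy.predict_step(21))  # 42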
