@@ -19,7 +19,7 @@
 import platform
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import Any, cast, Dict, Generator, List, Mapping, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -48,12 +48,12 @@
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _LRScheduler, _PATH, LRSchedulerConfig, ReduceLROnPlateau, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

 warning_cache = WarningCache()

-_DEEPSPEED_AVAILABLE: bool = _RequirementAvailable("deepspeed")
+_DEEPSPEED_AVAILABLE = _RequirementAvailable("deepspeed")
 if _DEEPSPEED_AVAILABLE:
     import deepspeed

@@ -76,7 +76,7 @@ def __init__(
         super().__init__(pl_module)
         self.precision = precision

-    def forward(self, *inputs, **kwargs):
+    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         inputs = apply_to_collection(inputs, Tensor, function=self._batch_to)
         return super().forward(*inputs, **kwargs)

@@ -123,7 +123,7 @@ def __init__(
         reduce_bucket_size: int = 200_000_000,
         zero_allow_untested_optimizer: bool = True,
         logging_batch_size_per_gpu: Union[str, int] = "auto",
-        config: Optional[Union[Path, str, dict]] = None,
+        config: Optional[Union[_PATH, Dict[str, Any]]] = None,
         logging_level: int = logging.WARN,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
@@ -142,7 +142,7 @@ def __init__(
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
-        lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#deepspeed`.
+        lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed`.

         .. warning:: ``DeepSpeedStrategy`` is in beta and subject to change.

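For orientation, a minimal usage sketch of the strategy documented above; it assumes the public `pytorch_lightning.strategies.DeepSpeedStrategy` constructor and standard `Trainer` wiring (the `stage` argument is not shown in this diff and is taken on assumption):

# Minimal sketch, not an authoritative example: `config` may be a dict or a path to a
# DeepSpeed JSON file, matching the Optional[Union[_PATH, Dict[str, Any]]] annotation above.
import pytorch_lightning as pl
from pytorch_lightning.strategies import DeepSpeedStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    precision=16,
    strategy=DeepSpeedStrategy(stage=3),  # or DeepSpeedStrategy(config="ds_config.json")
)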
@@ -331,7 +331,7 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale

-    def _load_config(self, config):
+    def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
         if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
             config = os.environ[self.DEEPSPEED_ENV_VAR]
@@ -342,9 +342,10 @@ def _load_config(self, config):
                 )
             with open(config) as f:
                 config = json.load(f)
+        assert isinstance(config, dict) or config is None
         return config

-    def setup_distributed(self):
+    def setup_distributed(self) -> None:
         reset_seed()

         # determine which process we are and world size
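As a companion to `_load_config` above, a hedged sketch of the three accepted `config` inputs; the environment-variable name is an assumption inferred from the `DEEPSPEED_ENV_VAR` attribute referenced in the code (`PL_DEEPSPEED_CONFIG_PATH` in upstream Lightning):

import os

from pytorch_lightning.strategies import DeepSpeedStrategy

# 1) a dict passes the new isinstance assert and is returned unchanged
DeepSpeedStrategy(config={"zero_optimization": {"stage": 2}})

# 2) a str/Path pointing at a JSON file is opened and json.load()-ed by _load_config
DeepSpeedStrategy(config="ds_config.json")

# 3) config=None falls back to the environment variable checked above
#    (assumed here to be "PL_DEEPSPEED_CONFIG_PATH")
os.environ["PL_DEEPSPEED_CONFIG_PATH"] = "/path/to/ds_config.json"
DeepSpeedStrategy(config=None)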
@@ -357,8 +358,10 @@ def setup_distributed(self):
         self._config_initialized = True

     def setup(self, trainer: "pl.Trainer") -> None:
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)
         # we set the device so that optimizers can be created with distributed comms.
+        assert self.lightning_module is not None
         self.lightning_module._device = self.root_device
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
@@ -367,6 +370,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.barrier()

     def _init_deepspeed_distributed(self) -> None:
+        assert self.cluster_environment is not None
         if platform.system() != "Windows":
             # do not set env variables on windows, allow deepspeed to control setup
             self._set_node_environment_variables()
@@ -378,14 +382,15 @@ def _init_deepspeed_distributed(self) -> None:
         self._process_group_backend = self._get_process_group_backend()
         deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)

-    def _get_process_group_backend(self):
+    def _get_process_group_backend(self) -> str:
         return (
             self._process_group_backend
             or _get_process_group_backend_from_env()
             or get_default_process_group_backend_for_device(self.root_device)
         )

     def _set_node_environment_variables(self) -> None:
+        assert self.cluster_environment is not None
         os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
         os.environ["RANK"] = str(self.global_rank)
@@ -396,7 +401,9 @@ def _set_node_environment_variables(self) -> None:
     def restore_checkpoint_after_setup(self) -> bool:
         return True

-    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+    def _setup_model_and_optimizers(
+        self, model: Module, optimizers: List[Optimizer]
+    ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]:
         """Setup a model and multiple optimizers together.

         Currently only a single optimizer is supported.
@@ -414,14 +421,18 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # normally we set this to the batch size, but it is not available here unless the user provides it
         # as part of the config
+        assert self.config is not None
         self.config.setdefault("train_micro_batch_size_per_gpu", 1)
         self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
         return self.model, [optimizer]

     def _setup_model_and_optimizer(
-        self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
-    ):
+        self,
+        model: Module,
+        optimizer: Optional[Optimizer],
+        lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+    ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.

         This calls :func:`deepspeed.initialize` internally.
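For context on the `deepspeed.initialize` call this method wraps, a hedged sketch of its return shape (based on DeepSpeed's documented API; the module and config values below are placeholders, not values used by the strategy):

import deepspeed
import torch

# deepspeed.initialize returns a 4-tuple; the strategy keeps the engine and the
# wrapped optimizer and ignores the dataloader/scheduler slots it does not use here.
model = torch.nn.Linear(4, 4)  # placeholder module
engine, optimizer, _, _ = deepspeed.initialize(
    config={"train_micro_batch_size_per_gpu": 1},
    model=model,
    model_parameters=model.parameters(),
    dist_init_required=False,
)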
@@ -431,14 +442,15 @@ def _setup_model_and_optimizer(
             args=argparse.Namespace(device_rank=self.root_device.index),
             config=self.config,
             model=model,
-            model_parameters=model_parameters,  # type: ignore
+            model_parameters=model_parameters,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
             dist_init_required=False,
         )
         return deepspeed_engine, deepspeed_optimizer

-    def init_deepspeed(self):
+    def init_deepspeed(self) -> None:
+        assert self.lightning_module is not None
         # deepspeed handles gradient clipping internally
         if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
             rank_zero_warn(
@@ -464,6 +476,7 @@ def init_deepspeed(self):
                 "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
             )

+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision)

         if self.lightning_module.trainer and self.lightning_module.trainer.training:
@@ -472,6 +485,7 @@ def init_deepspeed(self):
             self._initialize_deepspeed_inference(model)

     def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Optional[int]]:
+        assert self.lightning_module is not None
         optimizers, lr_schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
         if len(optimizers) > 1 or len(lr_schedulers) > 1:
             raise MisconfigurationException(
@@ -485,10 +499,13 @@ def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Opti

     @property
     def zero_stage_3(self) -> bool:
-        return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3
+        assert isinstance(self.config, dict)
+        zero_optimization = self.config.get("zero_optimization")
+        return zero_optimization is not None and zero_optimization.get("stage") == 3

-    def _initialize_deepspeed_train(self, model):
+    def _initialize_deepspeed_train(self, model: Module) -> None:
         optimizer, scheduler = None, None
+        assert isinstance(self.config, dict)
         if "optimizer" in self.config:
             rank_zero_info(
                 "You have specified an optimizer and/or scheduler within the DeepSpeed config."
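To make the rewritten `zero_stage_3` check concrete, a small self-contained sketch with an assumed config dict:

# True only when the config carries a "zero_optimization" section whose "stage" is 3,
# mirroring the property above.
config = {"zero_optimization": {"stage": 3}}
zero_optimization = config.get("zero_optimization")
assert zero_optimization is not None and zero_optimization.get("stage") == 3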
@@ -538,7 +555,8 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         with model_parallel_context:
             yield

-    def _set_deepspeed_activation_checkpointing(self):
+    def _set_deepspeed_activation_checkpointing(self) -> None:
+        assert isinstance(self.config, dict)
         if self.config.get("activation_checkpointing"):
             checkpoint_config = self.config["activation_checkpointing"]
             deepspeed.checkpointing.configure(
@@ -549,8 +567,9 @@ def _set_deepspeed_activation_checkpointing(self):
                 profile=checkpoint_config.get("profile"),
             )

-    def _initialize_deepspeed_inference(self, model):
+    def _initialize_deepspeed_inference(self, model: Module) -> None:
         # todo: Currently DeepSpeed requires optimizers at inference to partition weights correctly
+        assert isinstance(self.config, dict)
         optimizer, scheduler = None, None
         if "optimizer" not in self.config:
             rank_zero_info(
@@ -585,13 +604,15 @@ def _initialize_deepspeed_inference(self, model):
         self.model = model

     @property
-    def lightning_module(self):
+    def lightning_module(self) -> Optional["pl.LightningModule"]:
         # the model may not be wrapped with DeepEngine & LightningDeepSpeedModule if calling this too early
         module = getattr(self.model, "module", self.model)
-        return module.module if isinstance(module, LightningDeepSpeedModule) else module
+        module = module.module if isinstance(module, LightningDeepSpeedModule) else module
+        assert isinstance(module, pl.LightningModule) or module is None
+        return module

     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, int]:
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs

@@ -616,17 +637,18 @@ def handles_gradient_accumulation(self) -> bool:
         """Whether the plugin handles gradient accumulation internally."""
         return True

-    def _format_config(self):
+    def _format_config(self) -> None:
         if self.config is None:
             raise MisconfigurationException(
                 "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
-                " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#deepspeed"
+                " See: https://pytorch-lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed"
             )
         self._format_batch_size_and_grad_accum_config()
         self._format_precision_config()

-    def _format_batch_size_and_grad_accum_config(self):
+    def _format_batch_size_and_grad_accum_config(self) -> None:
         # todo: using lite, we do not support these variables within the config
+        assert isinstance(self.config, dict)
         if self.lightning_module is None:
             return

@@ -642,16 +664,17 @@ def _format_batch_size_and_grad_accum_config(self):
         if "gradient_clipping" not in self.config:
             self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0

-    def _auto_select_batch_size(self):
+    def _auto_select_batch_size(self) -> int:
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # by default we try to use the batch size of the loader
+        assert self.lightning_module is not None
         batch_size = 1
         train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source
         if train_dl_source.is_defined():
             try:
                 train_dataloader = train_dl_source.dataloader()
                 if hasattr(train_dataloader, "batch_sampler"):
-                    batch_size = train_dataloader.batch_sampler.batch_size
+                    batch_size = train_dataloader.batch_sampler.batch_size  # type: ignore[union-attr]
             # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup`
             # to have been called before
             except Exception:
@@ -664,6 +687,7 @@ def _auto_select_batch_size(self):
         return batch_size

     def _format_precision_config(self) -> None:
+        assert isinstance(self.config, dict)
         if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
             if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
@@ -707,7 +731,7 @@ def _create_default_config(
         single_submit: bool,
         overlap_events: bool,
         thread_count: int,
-        **zero_kwargs,
+        **zero_kwargs: Any,
     ) -> Dict:
         cfg = {
             "activation_checkpointing": {
@@ -753,7 +777,7 @@ def _create_default_config(
         return cfg

     @property
-    def deepspeed_engine(self):
+    def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine":
         return self.model

     @property
@@ -786,7 +810,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Op
                 "When saving the DeepSpeed Stage 3 checkpoint, "
                 "each worker will save a shard of the checkpoint within a directory. "
                 "If a single file is required after training, "
-                "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
+                "see https://pytorch-lightning.readthedocs.io/en/stable/advanced/model_parallel.html#"
                 "deepspeed-zero-stage-3-single-file for instructions."
             )
         # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
@@ -799,10 +823,12 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         if self.load_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
-            checkpoint_path = self.broadcast(checkpoint_path)
+            checkpoint_path = cast(_PATH, self.broadcast(checkpoint_path))
             return super().load_checkpoint(checkpoint_path)

         # Rely on deepspeed to load the checkpoint and necessary information
+        assert self.lightning_module is not None
+
         from pytorch_lightning.trainer.states import TrainerFn

         is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
@@ -818,6 +844,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:

     @property
     def lightning_restore_optimizer(self) -> bool:
+        assert self.lightning_module is not None
         # managed by DeepSpeed
         if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
             rank_zero_warn(
@@ -842,11 +869,13 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
             ckpt: The ckpt file.
         """

-        def load(module: torch.nn.Module, prefix=""):
+        assert self.lightning_module is not None
+
+        def load(module: torch.nn.Module, prefix: str = "") -> None:

-            missing_keys = []
-            unexpected_keys = []
-            error_msgs = []
+            missing_keys: List[str] = []
+            unexpected_keys: List[str] = []
+            error_msgs: List[str] = []
             state_dict = ckpt["state_dict"]

             # copy state_dict so _load_from_state_dict can modify it
@@ -914,14 +943,17 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             offload_optimizer_device="nvme",
         )

-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.val_step_context():
             return self.model(*args, **kwargs)

-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.test_step_context():
             return self.model(*args, **kwargs)

-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.predict_step_context():
             return self.model(*args, **kwargs)