@@ -102,6 +102,7 @@ def __init__(
         cpu_checkpointing: bool = False,
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
+        save_full_weights: bool = True,
     ) -> None:
         """

@@ -177,11 +178,16 @@ def __init__(
                 Not supported by all models
 
             synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
+
+            save_full_weights: Gathers weights across all processes before saving to disk
+                when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
+                rather than individual sharded weight files.
+                Disable to save sharded states individually. (Default: True)
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
                 "To use the DeepSpeed plugin, you must have DeepSpeed installed."
-                " pip install deepspeed mpi4py"
+                " pip install deepspeed"
             )
         super().__init__(
             parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment
@@ -205,11 +211,13 @@ def __init__(
             allgather_partitions=allgather_partitions,
             reduce_scatter=reduce_scatter,
             allgather_bucket_size=allgather_bucket_size,
-            reduce_bucket_size=reduce_bucket_size
+            reduce_bucket_size=reduce_bucket_size,
         )
         self._config_initialized = False
         deepspeed.utils.logging.logger.setLevel(logging_level)
 
+        self.save_full_weights = save_full_weights
+
         # default FP16 parameters.
         self.loss_scale = loss_scale
         self.initial_scale_power = initial_scale_power
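
For context on the new flag, a minimal usage sketch (the `stage` argument, `precision=16`, and `plugins=` wiring are assumptions about the surrounding DeepSpeedPlugin/Trainer API, not part of this diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Default (save_full_weights=True): rank 0 gathers the full ZeRO Stage 3
# model and writes a single checkpoint file.
trainer = Trainer(gpus=4, precision=16, plugins=DeepSpeedPlugin(stage=3))

# Opt out to keep deepspeed's sharded checkpointing, where each rank
# saves its own partition of weights and optimizer state.
trainer = Trainer(gpus=4, precision=16, plugins=DeepSpeedPlugin(stage=3, save_full_weights=False))
```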
@@ -472,17 +480,27 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
+            checkpoint: The checkpoint state dictionary
             filepath: write-target file's path
-            weights_only: saving model weights only
         """
         if self.world_size > 1 and self.zero_stage_3:
+            if self.save_full_weights:
+                # todo: expose this as general function in deepspeed
+                state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
+                if self.is_global_zero:
+                    # State dict keys will include reference to wrapper LightningDeepSpeedModule
+                    # Delete `module` prefix before saving.
+                    state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
+                    checkpoint['state_dict'] = state_dict
+                    return super().save_checkpoint(checkpoint, filepath)
+                return
+
             # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
             # dump states as a checkpoint dictionary object
             save_dir = self._filepath_to_dir(filepath)
             _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
             checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
             self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
-
         else:
             super().save_checkpoint(checkpoint, filepath)
@@ -491,7 +509,8 @@ def restore_model_state_from_ckpt_path(
         ckpt_path: str,
         map_location: Callable = lambda storage, loc: storage,
     ) -> Tuple[Dict, bool]:
-        if self.world_size > 1:
+        if not self.save_full_weights and self.world_size > 1:
+            # Rely on deepspeed to load the checkpoint and necessary information
             from pytorch_lightning.trainer.states import TrainerState
             stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
             save_dir = self._filepath_to_dir(ckpt_path)
@@ -511,6 +530,10 @@ def restore_model_state_from_ckpt_path(
             # hook: give user access to checkpoint if needed.
             self.lightning_module.on_load_checkpoint(client_state)
             return client_state, False
+
+        # Broadcast to ensure we load from the rank 0 checkpoint
+        # This doesn't have to be the case when using deepspeed sharded checkpointing
+        ckpt_path = self.broadcast(ckpt_path)
         return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)
 
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
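
The broadcast matters because, with `save_full_weights=True`, every rank must open the single checkpoint file written by rank 0, so any rank-local path (e.g. derived from a metric-based filename) is replaced with rank 0's. A sketch of the idea behind `self.broadcast(ckpt_path)`, assuming an initialized `torch.distributed` process group (the diff itself uses Lightning's own broadcast helper):

```python
import torch.distributed as dist

def broadcast_path(path: str, src: int = 0) -> str:
    # broadcast_object_list overwrites the list in place on every
    # non-source rank, so all ranks return rank 0's checkpoint path.
    obj = [path]
    dist.broadcast_object_list(obj, src=src)
    return obj[0]
```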