# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import contextlib
import json
import logging
import os
+from collections import OrderedDict
from pathlib import Path
from types import SimpleNamespace
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import torch
-from torch.nn.parallel import DistributedDataParallel

+from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
    import deepspeed


+def remove_module_hooks(model: torch.nn.Module) -> None:
+    # todo (tchaton) awaiting this feature to move upstream to DeepSpeed
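+    # Clearing every hook container lets ``deepspeed.initialize`` re-wrap the model
+    # without stale forward/backward/state-dict hooks attached to its submodules.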
+    for module in model.modules():
+        module._backward_hooks = OrderedDict()
+        module._is_full_backward_hook = None
+        module._forward_hooks = OrderedDict()
+        module._forward_pre_hooks = OrderedDict()
+        module._state_dict_hooks = OrderedDict()
+        module._load_state_dict_pre_hooks = OrderedDict()
+
+
class LightningDeepSpeedModule(_LightningModuleWrapperBase):

    def __init__(self, pl_module: LightningModule, precision: int):
@@ -67,6 +79,8 @@ def __init__(
        zero_optimization: bool = True,
        stage: int = 2,
        cpu_offload: bool = False,
+        cpu_offload_params: bool = False,
+        cpu_offload_use_pin_memory: bool = False,
        contiguous_gradients: bool = True,
        overlap_comm: bool = True,
        allgather_partitions: bool = True,
@@ -80,10 +94,14 @@ def __init__(
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        loss_scale: float = 0,
-        initial_scale_power: int = 32,
+        initial_scale_power: int = 16,
        loss_scale_window: int = 1000,
        hysteresis: int = 2,
-        min_loss_scale: int = 1
+        min_loss_scale: int = 1,
+        partition_activations: bool = False,
+        cpu_checkpointing: bool = False,
+        contiguous_memory_optimization: bool = False,
+        synchronize_checkpoint_boundary: bool = False,
    ) -> None:
        """

@@ -106,6 +124,10 @@ def __init__(

            cpu_offload: Enable offloading optimizer memory and computation to CPU

+            cpu_offload_params: When using ZeRO stage 3, offload parameters to CPU
+
+            cpu_offload_use_pin_memory: When using ZeRO stage 3, pin memory on CPU
+
            contiguous_gradients: Copies gradients to a continuous buffer as they are produced.
                Avoids memory fragmentation during backwards. Useful when training large models. (default: True)

@@ -144,6 +166,17 @@ def __init__(

            min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1000)

+            partition_activations: Enables partitioning of activations when used with ZeRO stage 3.
+                Still requires you to wrap your forward functions in ``deepspeed.checkpointing.checkpoint``.
+                See the `DeepSpeed tutorial
+                <https://www.deepspeed.ai/tutorials/megatron/#deepspeed-activation-checkpoints-optional>`_
+
+            cpu_checkpointing: Offloads partitioned activations to CPU if ``partition_activations`` is enabled
+
+            contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory.
+                Not supported by all models.
+
+            synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
        """
        if not _DEEPSPEED_AVAILABLE:
            raise MisconfigurationException(
@@ -159,8 +192,14 @@ def __init__(
        self.config = self._create_default_config(
            zero_optimization,
            zero_allow_untested_optimizer,
+            partition_activations=partition_activations,
+            cpu_checkpointing=cpu_checkpointing,
+            contiguous_memory_optimization=contiguous_memory_optimization,
+            synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
            stage=stage,
            cpu_offload=cpu_offload,
+            cpu_offload_params=cpu_offload_params,
+            cpu_offload_use_pin_memory=cpu_offload_use_pin_memory,
            contiguous_gradients=contiguous_gradients,
            overlap_comm=overlap_comm,
            allgather_partitions=allgather_partitions,
@@ -200,9 +239,14 @@ def init_deepspeed(self):
        self._format_config()
        self._config_initialized = True

+        self._handle_gradient_accumulation_steps()
+
        precision = self.lightning_module.trainer.accelerator.precision
        model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

+        if self.on_gpu:
+            torch.cuda.set_device(self.root_device)
+
        if self.lightning_module.trainer and self.lightning_module.trainer.training:
            self._initialize_deepspeed_train(model)
        else:
@@ -220,9 +264,11 @@ def _init_scheduler_optimizer(self):
        optimizer = optimizers[0]
        return optimizer, scheduler, optimizer_frequencies

+    @property
+    def zero_stage_3(self) -> bool:
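+        # ZeRO stage 3 partitions the model parameters themselves (in addition to optimizer
+        # state and gradients), so several code paths below special-case it.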
+        return self.config.get('zero_optimization') and self.config.get('zero_optimization').get('stage') == 3
+
    def _initialize_deepspeed_train(self, model):
-        if self.on_gpu:
-            torch.cuda.set_device(self.root_device)
        optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
        if "optimizer" not in self.config:
            rank_zero_info(
@@ -239,21 +285,65 @@ def _initialize_deepspeed_train(self, model):
            lr_scheduler=lightning_scheduler,
            config_params=self.config,
        )
+        self._set_deepspeed_activation_checkpointing()

        # set optimizer for save/load, but deepspeed manages the specific optimizer logic
        self.lightning_module.trainer.optimizers = [optimizer]
+        self.lightning_module.trainer.schedulers = [lr_scheduler]
        self.model = model

+    @contextlib.contextmanager
+    def model_sharded_context(self) -> Generator[None, None, None]:
+        if self.zero_stage_3:
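+            # ``deepspeed.zero.Init`` creates each submodule with its parameters already
+            # partitioned across ranks; ``remote_device="cpu"`` keeps the partitioned
+            # weights in pinned CPU memory until the engine gathers them.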
+            model_parallel_context = deepspeed.zero.Init(remote_device="cpu", pin_memory=True)
+        else:
+            model_parallel_context = super().model_sharded_context()
+
+        with model_parallel_context:
+            yield
+
+    def _set_deepspeed_activation_checkpointing(self):
+        if self.config.get('activation_checkpointing'):
+            checkpoint_config = self.config['activation_checkpointing']
+            # the config dict uses the DeepSpeed JSON field names, which differ from the
+            # keyword argument names of ``deepspeed.checkpointing.configure``
+            deepspeed.checkpointing.configure(
+                mpu_=None,
+                partition_activations=checkpoint_config.get('partition_activations'),
+                contiguous_checkpointing=checkpoint_config.get('contiguous_memory_optimization'),
+                checkpoint_in_cpu=checkpoint_config.get('cpu_checkpointing'),
+                profile=checkpoint_config.get('profile'),
+            )
+
    def _initialize_deepspeed_inference(self, model):
-        # move the model to the correct device
-        self.model_to_device()
-
-        self.pre_configure_ddp()
-        self.model = DistributedDataParallel(
-            model,
-            device_ids=self.determine_ddp_device_ids(),
-            **self._ddp_kwargs,
+        # todo: Currently DeepSpeed requires optimizers at inference to partition weights correctly
+        optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
+        if "optimizer" not in self.config:
+            rank_zero_info(
+                "You have not specified an optimizer or scheduler within the DeepSpeed config. "
+                "Using `configure_optimizers` to define optimizer and scheduler."
+            )
+            optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer()
+        inference_config = {
+            # todo: this is required for DeepSpeed throughput timers, or throughput timers will be incorrect
+            'train_micro_batch_size_per_gpu': 1,
+        }
+        if 'fp16' in self.config:
+            inference_config.update({"fp16": self.config["fp16"]})
+        if self.zero_stage_3:
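+            # for ZeRO stage 3 the engine also needs the ZeRO config at inference time,
+            # so that it can work with the partitioned (sharded) parameters correctly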
+            inference_config.update({
+                "zero_allow_untested_optimizer": self.config['zero_allow_untested_optimizer'],
+                "zero_optimization": self.config['zero_optimization'],
+            })
+        # Remove all module hooks before initializing new model
+        remove_module_hooks(model)
+        model, _, _, _ = deepspeed.initialize(
+            args=SimpleNamespace(local_rank=self.local_rank),
+            model=model,
+            optimizer=optimizer,
+            lr_scheduler=lightning_scheduler,
+            config_params=inference_config,
+            model_parameters=[],
        )
+        self.model = model

    def configure_scheduler(self, lr_scheduler):
        scheduler = _get_default_scheduler_config()
@@ -282,6 +372,20 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Calla
        # internally, the engine has a reference to the optimizer already.
        self.model.step(**kwargs)

+    def _handle_gradient_accumulation_steps(self):
+        """
+        Overrides ``trainer.accumulation_scheduler`` so that the Trainer effectively runs with
+        ``accumulate_grad_batches=1``. ``optimizer_step`` is then called for every batch, and
+        the DeepSpeed engine handles the gradient accumulation logic internally.
+        """
+        if self.config.get("gradient_accumulation_steps") > 1:
+            self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
+            # todo (tchaton) Add support for accumulate_grad_batches being a dictionary.
+            self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1})
+        else:
+            self._original_accumulate_grad_batches = None
+
    def _format_config(self):
        if self.config is None:
            raise MisconfigurationException(
@@ -300,14 +404,13 @@ def _format_batch_size_and_grad_accum_config(self):
        if "train_micro_batch_size_per_gpu" not in self.config:
            # train_micro_batch_size_per_gpu is used for throughput logging purposes
            # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
-            batch_size = self.lightning_module.train_dataloader().batch_size
+            batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size
            self.config["train_micro_batch_size_per_gpu"] = batch_size
        self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
        if "gradient_clipping" not in self.config:
            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val

    def _format_precision_config(self):
-
        amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
        amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
        precision = self.lightning_module.trainer.accelerator_connector.precision
@@ -333,8 +436,87 @@ def _format_precision_config(self):
            raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.")

    def _create_default_config(
-        self, zero_optimization: bool, zero_allow_untested_optimizer: bool, **zero_kwargs
+        self,
+        zero_optimization: bool,
+        zero_allow_untested_optimizer: bool,
+        partition_activations: bool,
+        cpu_checkpointing: bool,
+        contiguous_memory_optimization: bool,
+        synchronize_checkpoint_boundary: bool,
+        **zero_kwargs,
    ) -> Dict:
+        cfg = {
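+            # keys inside ``activation_checkpointing`` follow the DeepSpeed JSON config schema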
+            'activation_checkpointing': {
+                "partition_activations": partition_activations,
+                "cpu_checkpointing": cpu_checkpointing,
+                "contiguous_memory_optimization": contiguous_memory_optimization,
+                "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary
+            }
+        }
        if zero_optimization:
-            return {"zero_allow_untested_optimizer": zero_allow_untested_optimizer, "zero_optimization": zero_kwargs}
-        return {}
+            cfg = {
+                "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
+                "zero_optimization": zero_kwargs,
+                **cfg
+            }
+        return cfg
+
+    def _filepath_to_dir(self, filepath: str) -> str:
+        return os.path.dirname(filepath)
+
+    @property
+    def deepspeed_engine(self):
+        return self.model
+
+    def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+
+        Args:
+            checkpoint: the checkpoint state dictionary
+            filepath: write-target file's path
+        """
+        if self.world_size > 1 and self.zero_stage_3:
+            # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
+            # dump states as a checkpoint dictionary object
+            save_dir = self._filepath_to_dir(filepath)
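+            # each process saves its own shard of the partitioned states into ``save_dir``,
+            # which is why a directory rather than a single file is written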
+            _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
+            checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
+            self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
+
+        else:
+            super().save_checkpoint(checkpoint, filepath)
+
+    def restore_model_state_from_ckpt_path(
+        self,
+        ckpt_path: str,
+        map_location: Callable = lambda storage, loc: storage,
+    ) -> Tuple[Dict, bool]:
+        if self.world_size > 1:
+            from pytorch_lightning.trainer.states import TrainerState
+            stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
+            save_dir = self._filepath_to_dir(ckpt_path)
+
+            if self.zero_stage_3:
+                # TODO: Currently required as this call is missing within the deepspeed engine.
+                self.deepspeed_engine.optimizer._partition_all_parameters()
+
+            _, client_state = self.deepspeed_engine.load_checkpoint(
+                save_dir, load_optimizer_states=stage_is_fit, load_lr_scheduler_states=stage_is_fit
+            )
+
+            # restore datamodule states
+            if self.lightning_module.trainer.datamodule is not None:
+                self.lightning_module.trainer.datamodule.on_load_checkpoint(client_state)
+
+            # hook: give user access to checkpoint if needed.
+            self.lightning_module.on_load_checkpoint(client_state)
+            return client_state, False
+        return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)
+
+    def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
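+        # the Trainer was forced to ``accumulate_grad_batches=1`` in ``_handle_gradient_accumulation_steps``,
+        # while DeepSpeed steps the optimizer every ``_original_accumulate_grad_batches`` batches,
+        # so only advance the global step at that cadence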
+        if self._original_accumulate_grad_batches is None:
+            return super().update_global_step(total_batch_idx, current_global_step)
+        else:
+            if total_batch_idx % self._original_accumulate_grad_batches == 0:
+                current_global_step += 1
+            return current_global_step