
Commit 1302766

Authored by tchaton, SeanNaren, awaelchli, and carmocca
DeepSpeed ZeRO Update (#6546)
* Add context to call hook to handle all modules defined within the hook
* Expose some additional parameters
* Added docs, exposed parameters
* Make sure we only configure if necessary
* Setup activation checkpointing regardless, saves the user having to do it manually
* Add some tests that fail currently
* update
* update
* update
* add tests
* change docstring
* resolve accumulate_grad_batches
* resolve flake8
* Update DeepSpeed to use latest version, add some comments
* add metrics
* update
* Small formatting fixes, clean up some code
* Few cleanups
* No need for default state
* Fix tests, add some boilerplate that should move eventually
* Add hook removal
* Add a context manager to handle hook
* Small naming cleanup
* wip
* move save_checkpoint responsability to accelerator
* resolve flake8
* add BC
* Change recommended scale to 16
* resolve flake8
* update test
* update install
* update
* update test
* update
* update
* update test
* resolve flake8
* update
* update
* update on comments
* Push
* pull
* Update pytorch_lightning/plugins/training_type/deepspeed.py
  Co-authored-by: Adrian Wälchli <[email protected]>
* Update pytorch_lightning/plugins/training_type/deepspeed.py
  Co-authored-by: Adrian Wälchli <[email protected]>
* update
* Apply suggestions from code review
* Swap to using world size defined by plugin
* update
* update todo
* Remove deepspeed from extra, keep it in the base cuda docker install
* Push
* pull
* update
* update
* update
* update
* Minor changes
* duplicate
* format
* format2

Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 9876df1 commit 1302766

File tree

16 files changed: +549 -136 lines


dockers/base-cuda/Dockerfile

Lines changed: 0 additions & 5 deletions
@@ -113,11 +113,6 @@ RUN \
     pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex && \
     rm -rf apex

-RUN \
-    # install DeepSpeed from source.
-    # todo: swap to pypi release once DeepSpeed releases a new version >= 0.3.10
-    pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb
-
 RUN \
     # Show what we have
     pip --version && \
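With the git-pinned source install removed, DeepSpeed is expected to come from a regular package install. A minimal sanity check that the environment still provides it; the `_DEEPSPEED_AVAILABLE` import path is an assumption based on the guarded `import deepspeed` in the plugin diff below:

```python
# Hedged sketch: verify DeepSpeed is importable after dropping the git-pinned install.
# The location of _DEEPSPEED_AVAILABLE is an assumption; adjust to your Lightning version.
from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE

if _DEEPSPEED_AVAILABLE:
    import deepspeed
    print("DeepSpeed version:", deepspeed.__version__)
else:
    print("DeepSpeed is not installed; install the `deepspeed` package to use DeepSpeedPlugin.")
```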

pytorch_lightning/accelerators/accelerator.py

Lines changed: 4 additions & 1 deletion
@@ -441,7 +441,7 @@ def results(self) -> Any:
         return self.training_type_plugin.results

     @contextlib.contextmanager
-    def model_sharded_context(self) -> Generator:
+    def model_sharded_context(self) -> Generator[None, None, None]:
         """
         Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
         shard the model instantly - useful for extremely large models. Can save memory and
@@ -511,3 +511,6 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
         """
         return self.training_type_plugin.setup_optimizers_in_pre_dispatch
+
+    def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
+        return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)
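The accelerator-level `update_global_step` is pure delegation, so a training type plugin can reinterpret how batches map to global steps (DeepSpeed uses this below for internal gradient accumulation). A self-contained sketch of the delegation pattern; the stub classes are illustrative, not Lightning internals:

```python
# Illustrative stubs only: shows the delegation added above, not Lightning's real classes.
class TrainingTypePluginStub:
    def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
        # assumed default behaviour: every optimizer step advances the global step by one
        return current_global_step + 1


class AcceleratorStub:
    def __init__(self, training_type_plugin: TrainingTypePluginStub) -> None:
        self.training_type_plugin = training_type_plugin

    def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
        # mirrors the new Accelerator method: forward straight to the plugin
        return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)


accelerator = AcceleratorStub(TrainingTypePluginStub())
global_step = 0
for batch_idx in range(4):
    global_step = accelerator.update_global_step(batch_idx, global_step)
print(global_step)  # 4: one global step per batch with the stub default
```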

pytorch_lightning/plugins/precision/double.py

Lines changed: 2 additions & 4 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import wraps
-from typing import Any, Sequence, Tuple, TYPE_CHECKING, List
+from typing import Any, List, Sequence, Tuple, TYPE_CHECKING

 import torch

@@ -44,9 +44,7 @@ def _to_double_precision(data: torch.Tensor) -> torch.Tensor:

     @staticmethod
     def _move_float_tensors_to_double(collection: Any) -> Any:
-        return apply_to_collection(
-            collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision
-        )
+        return apply_to_collection(collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision)

     @classmethod
     def patch(cls, model: 'Module', method_name: str) -> '_DoublePrecisionPatch':

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 202 additions & 20 deletions
@@ -11,17 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import contextlib
 import json
 import logging
 import os
+from collections import OrderedDict
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

 import torch
-from torch.nn.parallel import DistributedDataParallel

+from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -37,6 +38,17 @@
     import deepspeed


+def remove_module_hooks(model: torch.nn.Module) -> None:
+    # todo (tchaton) awaiting this feature to move upstream to DeepSpeed
+    for module in model.modules():
+        module._backward_hooks = OrderedDict()
+        module._is_full_backward_hook = None
+        module._forward_hooks = OrderedDict()
+        module._forward_pre_hooks = OrderedDict()
+        module._state_dict_hooks = OrderedDict()
+        module._load_state_dict_pre_hooks = OrderedDict()
+
+
 class LightningDeepSpeedModule(_LightningModuleWrapperBase):

     def __init__(self, pl_module: LightningModule, precision: int):
@@ -67,6 +79,8 @@ def __init__(
         zero_optimization: bool = True,
         stage: int = 2,
         cpu_offload: bool = False,
+        cpu_offload_params: bool = False,
+        cpu_offload_use_pin_memory: bool = False,
         contiguous_gradients: bool = True,
         overlap_comm: bool = True,
         allgather_partitions: bool = True,
@@ -80,10 +94,14 @@ def __init__(
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         loss_scale: float = 0,
-        initial_scale_power: int = 32,
+        initial_scale_power: int = 16,
         loss_scale_window: int = 1000,
         hysteresis: int = 2,
-        min_loss_scale: int = 1
+        min_loss_scale: int = 1,
+        partition_activations: bool = False,
+        cpu_checkpointing: bool = False,
+        contiguous_memory_optimization: bool = False,
+        synchronize_checkpoint_boundary: bool = False,
     ) -> None:
         """

@@ -106,6 +124,10 @@ def __init__(

             cpu_offload: Enable offloading optimizer memory and computation to CPU

+            cpu_offload_params: When using ZeRO stage 3, offload parameters to CPU
+
+            cpu_offload_use_pin_memory: When using ZeRO stage 3, pin memory on CPU
+
             contiguous_gradients: Copies gradients to a continuous buffer as they are produced.
                 Avoids memory fragmentation during backwards. Useful when training large models. (default: True)

@@ -144,6 +166,17 @@

             min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1000)

+            partition_activations: Enables partition activation when used with ZeRO stage 3.
+                Still requires you to wrap your forward functions in deepspeed.checkpointing.checkpoint.
+                See `deepspeed tutorial
+                <https://www.deepspeed.ai/tutorials/megatron/#deepspeed-activation-checkpoints-optional>`_
+
+            cpu_checkpointing: Offloads partitioned activations to CPU if ``partition_activations`` is enabled
+
+            contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory.
+                Not supported by all models
+
+            synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
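For reference, the newly exposed arguments are plain constructor parameters, so they can be enabled directly when building the plugin; a short usage sketch (GPU count, precision and the chosen values are illustrative, not recommendations):

```python
# Usage sketch for the new DeepSpeedPlugin arguments; values are illustrative.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

plugin = DeepSpeedPlugin(
    stage=3,                          # ZeRO stage 3: partition optimizer state, gradients and parameters
    cpu_offload=True,                 # offload optimizer memory/computation to CPU
    cpu_offload_params=True,          # new: offload parameters to CPU under stage 3
    cpu_offload_use_pin_memory=True,  # new: pin the CPU memory used for offloading
    partition_activations=True,       # new: feeds the generated `activation_checkpointing` config
    cpu_checkpointing=True,           # new: move partitioned activations to CPU
)
trainer = Trainer(gpus=4, precision=16, plugins=plugin)
```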
@@ -159,8 +192,14 @@ def __init__(
         self.config = self._create_default_config(
             zero_optimization,
             zero_allow_untested_optimizer,
+            partition_activations=partition_activations,
+            cpu_checkpointing=cpu_checkpointing,
+            contiguous_memory_optimization=contiguous_memory_optimization,
+            synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
             stage=stage,
             cpu_offload=cpu_offload,
+            cpu_offload_params=cpu_offload_params,
+            cpu_offload_use_pin_memory=cpu_offload_use_pin_memory,
             contiguous_gradients=contiguous_gradients,
             overlap_comm=overlap_comm,
             allgather_partitions=allgather_partitions,
@@ -200,9 +239,14 @@ def init_deepspeed(self):
         self._format_config()
         self._config_initialized = True

+        self._handle_gradient_accumulation_steps()
+
         precision = self.lightning_module.trainer.accelerator.precision
         model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

+        if self.on_gpu:
+            torch.cuda.set_device(self.root_device)
+
         if self.lightning_module.trainer and self.lightning_module.trainer.training:
             self._initialize_deepspeed_train(model)
         else:
@@ -220,9 +264,11 @@ def _init_scheduler_optimizer(self):
         optimizer = optimizers[0]
         return optimizer, scheduler, optimizer_frequencies

+    @property
+    def zero_stage_3(self) -> bool:
+        return self.config.get('zero_optimization') and self.config.get('zero_optimization').get('stage') == 3
+
     def _initialize_deepspeed_train(self, model):
-        if self.on_gpu:
-            torch.cuda.set_device(self.root_device)
         optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
         if "optimizer" not in self.config:
             rank_zero_info(
@@ -239,21 +285,65 @@ def _initialize_deepspeed_train(self, model):
             lr_scheduler=lightning_scheduler,
             config_params=self.config,
         )
+        self._set_deepspeed_activation_checkpointing()

         # set optimizer for save/load, but deepspeed manages the specific optimizer logic
         self.lightning_module.trainer.optimizers = [optimizer]
+        self.lightning_module.trainer.schedulers = [lr_scheduler]
         self.model = model

+    @contextlib.contextmanager
+    def model_sharded_context(self) -> Generator[None, None, None]:
+        if self.zero_stage_3:
+            model_parallel_context = deepspeed.zero.Init(remote_device="cpu", pin_memory=True)
+        else:
+            model_parallel_context = super().model_sharded_context()
+
+        with model_parallel_context:
+            yield
+
+    def _set_deepspeed_activation_checkpointing(self):
+        if self.config.get('activation_checkpointing'):
+            checkpoint_config = self.config['activation_checkpointing']
+            deepspeed.checkpointing.configure(
+                mpu_=None,
+                partition_activations=checkpoint_config.get('partition_activations'),
+                contiguous_checkpointing=checkpoint_config.get('contiguous_checkpointing'),
+                checkpoint_in_cpu=checkpoint_config.get('checkpoint_in_cpu'),
+                profile=checkpoint_config.get('profile'),
+            )
+
     def _initialize_deepspeed_inference(self, model):
-        # move the model to the correct device
-        self.model_to_device()
-
-        self.pre_configure_ddp()
-        self.model = DistributedDataParallel(
-            model,
-            device_ids=self.determine_ddp_device_ids(),
-            **self._ddp_kwargs,
+        # todo: Currently DeepSpeed requires optimizers at inference to partition weights correctly
+        optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
+        if "optimizer" not in self.config:
+            rank_zero_info(
+                "You have not specified an optimizer or scheduler within the DeepSpeed config."
+                "Using `configure_optimizers` to define optimizer and scheduler."
+            )
+            optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer()
+        inference_config = {
+            # todo: this is required for DeepSpeed throughput timers, or throughput timers will be incorrect
+            'train_micro_batch_size_per_gpu': 1,
+        }
+        if 'fp16' in self.config:
+            inference_config.update({"fp16": self.config["fp16"]})
+        if self.zero_stage_3:
+            inference_config.update({
+                "zero_allow_untested_optimizer": self.config['zero_allow_untested_optimizer'],
+                "zero_optimization": self.config['zero_optimization'],
+            })
+        # Remove all module hooks before initializing new model
+        remove_module_hooks(model)
+        model, _, _, _ = deepspeed.initialize(
+            args=SimpleNamespace(local_rank=self.local_rank),
+            model=model,
+            optimizer=optimizer,
+            lr_scheduler=lightning_scheduler,
+            config_params=inference_config,
+            model_parameters=[],
         )
+        self.model = model

     def configure_scheduler(self, lr_scheduler):
         scheduler = _get_default_scheduler_config()
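Under ZeRO stage 3 the sharded context above is `deepspeed.zero.Init`, so layers created inside the sharding hook are partitioned as they are instantiated, and the docstring notes that `partition_activations` still requires wrapping checkpointed segments in `deepspeed.checkpointing.checkpoint`. A hedged sketch of what that can look like on the user side; the `configure_sharded_model` hook name is an assumption here, and `training_step`/`configure_optimizers` are omitted:

```python
# Hedged user-side sketch: layer creation inside the sharding hook plus activation checkpointing.
import torch
import deepspeed
from pytorch_lightning import LightningModule


class MyModel(LightningModule):

    def configure_sharded_model(self) -> None:
        # runs while the plugin's model_sharded_context() is active, so stage 3
        # can partition these weights as they are created
        self.block = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU())
        self.head = torch.nn.Linear(32, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # checkpointed segment; behaviour follows the `activation_checkpointing`
        # config that _set_deepspeed_activation_checkpointing() passes to DeepSpeed
        x = deepspeed.checkpointing.checkpoint(self.block, x)
        return self.head(x)
```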
@@ -282,6 +372,20 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Calla
         # internally, the engine has a reference to the optimizer already.
         self.model.step(**kwargs)

+    def _handle_gradient_accumulation_steps(self):
+        """
+        This functions overrides the trainer.accumulation_scheduler to generate
+        ``accumulate_grad_batches=1``.
+        Therefore, ``optimizer_step`` will be called on every batches seen
+        so DeepSpeed Engine handles the gradient accumulation logic internally.
+        """
+        if self.config.get("gradient_accumulation_steps") > 1:
+            self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
+            # todo (tchaton) Add support for accumulate_grad_batches being a dictionary.
+            self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1})
+        else:
+            self._original_accumulate_grad_batches = None
+
     def _format_config(self):
         if self.config is None:
             raise MisconfigurationException(
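The two pieces work together: when the DeepSpeed config requests `gradient_accumulation_steps > 1`, Lightning's accumulation scheduler is pinned to 1 so `optimizer_step` fires on every batch and the engine accumulates internally, while `update_global_step` (added at the end of this file) only advances the trainer's global step once per original accumulation window. A worked example of that counting, mirroring the modulo logic in the diff (the batch indexing convention is illustrative):

```python
# Worked example of the global-step compensation with accumulate_grad_batches=4.
original_accumulate_grad_batches = 4


def update_global_step(total_batch_idx: int, current_global_step: int) -> int:
    # same modulo rule as the plugin's update_global_step when accumulation is delegated
    if total_batch_idx % original_accumulate_grad_batches == 0:
        current_global_step += 1
    return current_global_step


global_step = 0
for batch_idx in range(1, 13):  # 12 batches seen by the engine
    global_step = update_global_step(batch_idx, global_step)
print(global_step)  # 3: one global step per 4 accumulated batches
```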
@@ -300,14 +404,13 @@ def _format_batch_size_and_grad_accum_config(self):
         if "train_micro_batch_size_per_gpu" not in self.config:
             # train_micro_batch_size_per_gpu is used for throughput logging purposes
             # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
-            batch_size = self.lightning_module.train_dataloader().batch_size
+            batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size
             self.config["train_micro_batch_size_per_gpu"] = batch_size
         self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
         if "gradient_clipping" not in self.config:
             self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val

     def _format_precision_config(self):
-
         amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
         amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
         precision = self.lightning_module.trainer.accelerator_connector.precision
@@ -333,8 +436,87 @@ def _format_precision_config(self):
             raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.")

     def _create_default_config(
-        self, zero_optimization: bool, zero_allow_untested_optimizer: bool, **zero_kwargs
+        self,
+        zero_optimization: bool,
+        zero_allow_untested_optimizer: bool,
+        partition_activations: bool,
+        cpu_checkpointing: bool,
+        contiguous_memory_optimization: bool,
+        synchronize_checkpoint_boundary: bool,
+        **zero_kwargs,
     ) -> Dict:
+        cfg = {
+            'activation_checkpointing': {
+                "partition_activations": partition_activations,
+                "cpu_checkpointing": cpu_checkpointing,
+                "contiguous_memory_optimization": contiguous_memory_optimization,
+                "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary
+            }
+        }
         if zero_optimization:
-            return {"zero_allow_untested_optimizer": zero_allow_untested_optimizer, "zero_optimization": zero_kwargs}
-        return {}
+            cfg = {
+                "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
+                "zero_optimization": zero_kwargs,
+                **cfg
+            }
+        return cfg
+
+    def _filepath_to_dir(self, filepath: str) -> str:
+        return os.path.dirname(filepath)
+
+    @property
+    def deepspeed_engine(self):
+        return self.model
+
+    def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+
+        Args:
+            filepath: write-target file's path
+            weights_only: saving model weights only
+        """
+        if self.world_size > 1 and self.zero_stage_3:
+            # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
+            # dump states as a checkpoint dictionary object
+            save_dir = self._filepath_to_dir(filepath)
+            _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
+            checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
+            self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
+
+        else:
+            super().save_checkpoint(checkpoint, filepath)
+
+    def restore_model_state_from_ckpt_path(
+        self,
+        ckpt_path: str,
+        map_location: Callable = lambda storage, loc: storage,
+    ) -> Tuple[Dict, bool]:
+        if self.world_size > 1:
+            from pytorch_lightning.trainer.states import TrainerState
+            stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
+            save_dir = self._filepath_to_dir(ckpt_path)
+
+            if self.zero_stage_3:
+                # TODO: Currently required as this call is missing within the deepspeed engine.
+                self.deepspeed_engine.optimizer._partition_all_parameters()
+
+            _, client_state = self.deepspeed_engine.load_checkpoint(
+                save_dir, load_optimizer_states=stage_is_fit, load_lr_scheduler_states=stage_is_fit
+            )
+
+            # restore datamodule states
+            if self.lightning_module.trainer.datamodule is not None:
+                self.lightning_module.trainer.datamodule.on_load_checkpoint(client_state)
+
+            # hook: give user access to checkpoint if needed.
+            self.lightning_module.on_load_checkpoint(client_state)
+            return client_state, False
+        return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)
+
+    def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
+        if self._original_accumulate_grad_batches is None:
+            return super().update_global_step(total_batch_idx, current_global_step)
+        else:
+            if total_batch_idx % self._original_accumulate_grad_batches == 0:
+                current_global_step += 1
+            return current_global_step
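With `world_size > 1` and ZeRO stage 3, `save_checkpoint` defers to `deepspeed_engine.save_checkpoint`, so the path handed to the Trainer ends up as a DeepSpeed checkpoint directory (partitioned shards plus the remaining Lightning client state) rather than a single file, and restoring goes back through `deepspeed_engine.load_checkpoint`. A hedged end-to-end sketch; paths, device counts and the model are illustrative:

```python
# Hedged sketch: saving and resuming a ZeRO stage 3 run; MyModel is any user LightningModule.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

model = MyModel()  # user-defined LightningModule; definition omitted here
trainer = Trainer(gpus=2, precision=16, plugins=DeepSpeedPlugin(stage=3))
trainer.fit(model)

# becomes a checkpoint *directory* of partitioned states + client state under stage 3
trainer.save_checkpoint("checkpoints/zero3.ckpt")

# resuming routes through restore_model_state_from_ckpt_path -> deepspeed_engine.load_checkpoint
trainer = Trainer(
    gpus=2,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3),
    resume_from_checkpoint="checkpoints/zero3.ckpt",
)
trainer.fit(model)
```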
