Commit b46cc55

Author: Sean Naren
[Feat] DeepSpeed single file saving (#6900)
* Add single checkpoint capability
* Fix checkpointing in test, few cleanups
* Add comment
* Change restore logic
* Move vars around, add better explanation, make todo align with DeepSpeed team
* Fix checkpointing
* Remove deepspeed from extra, install in Dockerfile
* push
* pull
* Split to two tests to see if it fixes Deepspeed error
* Add comment
1 parent e9c3e02 commit b46cc55
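
The feature is driven entirely by the new `save_full_weights` argument on `DeepSpeedPlugin`, which the updated tests below exercise. A minimal, hypothetical usage sketch (the `LitModel` module and trainer settings are illustrative, not taken from this commit):

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin


class LitModel(LightningModule):
    # Minimal hypothetical module, only here to make the sketch self-contained.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)


trainer = Trainer(
    gpus=2,
    precision=16,
    max_epochs=1,
    # save_full_weights=True (the default) gathers the ZeRO Stage 3 shards on rank 0,
    # so checkpoints are written as a single file rather than one shard per rank.
    plugins=[DeepSpeedPlugin(stage=3, save_full_weights=True)],
)
trainer.fit(LitModel())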

File tree: 4 files changed, +58 -17 lines changed

dockers/base-cuda/Dockerfile

Lines changed: 4 additions & 0 deletions
@@ -113,6 +113,10 @@ RUN \
     pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex && \
     rm -rf apex
 
+RUN \
+    # install DeepSpeed
+    pip install deepspeed>=0.3.14
+
 RUN \
     # Show what we have
     pip --version && \
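
As a quick sanity check that the rebuilt image ships the pinned DeepSpeed release, the version can be read back from inside the container; this is an illustrative check, not part of the commit:

import deepspeed

# The Dockerfile above installs deepspeed>=0.3.14; print what the image actually resolved to.
print("deepspeed:", deepspeed.__version__)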

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 28 additions & 5 deletions
@@ -102,6 +102,7 @@ def __init__(
         cpu_checkpointing: bool = False,
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
+        save_full_weights: bool = True,
     ) -> None:
         """
 
@@ -177,11 +178,16 @@ def __init__(
                 Not supported by all models
 
             synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
+
+            save_full_weights: Gathers weights across all processes before saving to disk
+                when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
+                rather than individual sharded weight files.
+                Disable to save sharded states individually. (Default: True)
+
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
                 "To use the DeepSpeed plugin, you must have DeepSpeed installed."
-                " pip install deepspeed mpi4py"
+                " pip install deepspeed"
             )
         super().__init__(
             parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment
@@ -205,11 +211,13 @@ def __init__(
             allgather_partitions=allgather_partitions,
             reduce_scatter=reduce_scatter,
             allgather_bucket_size=allgather_bucket_size,
-            reduce_bucket_size=reduce_bucket_size
+            reduce_bucket_size=reduce_bucket_size,
         )
         self._config_initialized = False
         deepspeed.utils.logging.logger.setLevel(logging_level)
 
+        self.save_full_weights = save_full_weights
+
         # default FP16 parameters.
         self.loss_scale = loss_scale
         self.initial_scale_power = initial_scale_power
@@ -472,17 +480,27 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
+            checkpoint: The checkpoint state dictionary
             filepath: write-target file's path
-            weights_only: saving model weights only
         """
         if self.world_size > 1 and self.zero_stage_3:
+            if self.save_full_weights:
+                # todo: expose this as general function in deepspeed
+                state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
+                if self.is_global_zero:
+                    # State dict keys will include reference to wrapper LightningDeepSpeedModule
+                    # Delete `module` prefix before saving.
+                    state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
+                    checkpoint['state_dict'] = state_dict
+                    return super().save_checkpoint(checkpoint, filepath)
+                return
+
             # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
             # dump states as a checkpoint dictionary object
             save_dir = self._filepath_to_dir(filepath)
             _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
             checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
             self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
-
         else:
             super().save_checkpoint(checkpoint, filepath)
 
@@ -491,7 +509,8 @@ def restore_model_state_from_ckpt_path(
         ckpt_path: str,
         map_location: Callable = lambda storage, loc: storage,
     ) -> Tuple[Dict, bool]:
-        if self.world_size > 1:
+        if not self.save_full_weights and self.world_size > 1:
+            # Rely on deepspeed to load the checkpoint and necessary information
            from pytorch_lightning.trainer.states import TrainerState
            stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
            save_dir = self._filepath_to_dir(ckpt_path)
@@ -511,6 +530,10 @@ def restore_model_state_from_ckpt_path(
            # hook: give user access to checkpoint if needed.
            self.lightning_module.on_load_checkpoint(client_state)
            return client_state, False
+
+        # Broadcast to ensure we load from the rank 0 checkpoint
+        # This doesn't have to be the case when using deepspeed sharded checkpointing
+        ckpt_path = self.broadcast(ckpt_path)
         return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)
 
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
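
Because `save_full_weights` consolidates the ZeRO Stage 3 shards on rank 0 and strips the `LightningDeepSpeedModule` wrapper's `module.` prefix before writing, the resulting file looks like an ordinary Lightning checkpoint. A minimal sketch of inspecting one (the path below is a placeholder, not from this commit):

import torch

# Placeholder path to a checkpoint produced with DeepSpeedPlugin(stage=3, save_full_weights=True).
ckpt = torch.load("lightning_logs/version_0/checkpoints/last.ckpt", map_location="cpu")

# The consolidated weights live under "state_dict", with the "module." prefix already removed,
# alongside the usual Lightning metadata (epoch, global_step, callbacks, ...).
for name, tensor in ckpt["state_dict"].items():
    print(name, tuple(tensor.shape))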

requirements/extra.txt

Lines changed: 0 additions & 1 deletion
@@ -10,4 +10,3 @@ hydra-core>=1.0
 # todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs
 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip
 jsonargparse[signatures]>=3.3.1
-deepspeed>=0.3.13

tests/plugins/test_deepspeed_plugin.py

Lines changed: 26 additions & 11 deletions
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Any
+from typing import Any, Dict
 
 import pytest
 import torch
@@ -28,6 +28,9 @@ def __init__(self):
     def configure_sharded_model(self) -> None:
         self.linear = torch.nn.Linear(32, 2)
 
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        self.configure_sharded_model()
+
 
 def test_deepspeed_lightning_module(tmpdir):
     """
@@ -456,23 +459,17 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
     trainer.fit(model)
     trainer.test(model)
 
-    # todo (tchaton) Currently load_from_checkpoint is not support for zero-v3
-    # _assert_save_model_is_equal(model, tmpdir, trainer)
+    _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel)
 
 
-@RunIf(min_gpus=2, deepspeed=True, special=True)
-def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
-    """
-    Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
-    and see convergence.
-    """
+def run_checkpoint_test(tmpdir, save_full_weights):
     seed_everything(42)
     model = ModelParallelClassificationModel()
     dm = ClassifDataModule()
     ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
     trainer = Trainer(
         max_epochs=10,
-        plugins=[DeepSpeedPlugin(stage=3)],
+        plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
         default_root_dir=tmpdir,
         gpus=2,
         precision=16,
@@ -490,7 +487,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
 
     trainer = Trainer(
         max_epochs=10,
-        plugins=[DeepSpeedPlugin(stage=3)],
+        plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
         default_root_dir=tmpdir,
         gpus=2,
         precision=16,
@@ -506,6 +503,24 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
     assert results[-1] > 0.7
 
 
+@RunIf(min_gpus=2, deepspeed=True, special=True)
+def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
+    """
+    Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
+    and see convergence.
+    """
+    run_checkpoint_test(tmpdir, save_full_weights=False)
+
+
+@RunIf(min_gpus=2, deepspeed=True, special=True)
+def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir):
+    """
+    Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
+    where we save the full weights to one file.
+    """
+    run_checkpoint_test(tmpdir, save_full_weights=True)
+
+
 @RunIf(min_gpus=2, deepspeed=True, special=True)
 @pytest.mark.parametrize('cpu_offload', [True, False])
 def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, cpu_offload):
