diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml
index 05e8624b72630..5333bfd867da0 100644
--- a/.azure-pipelines/gpu-tests.yml
+++ b/.azure-pipelines/gpu-tests.yml
@@ -57,6 +57,7 @@ jobs:
   - bash: |
       python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
       pip install fairscale>=0.3.4
+      pip install deepspeed>=0.4.0 -U
       pip install . --requirement requirements/devel.txt
       pip list
     displayName: 'Install dependencies'

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 294c11a99b70d..0829c7e069ff2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -80,6 +80,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added LightningCLI support for argument links applied on instantiation ([#7895](https://github.com/PyTorchLightning/pytorch-lightning/pull/7895))

+- Added DeepSpeed Infinity support and updated to DeepSpeed 0.4.0 ([#7234](https://github.com/PyTorchLightning/pytorch-lightning/pull/7234))
+
+
 - Added support for `torch.nn.UninitializedParameter` in `ModelSummary` ([#7642](https://github.com/PyTorchLightning/pytorch-lightning/pull/7642))

diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile
index e16971bdc2a1a..5c15e096cfb4b 100644
--- a/dockers/base-cuda/Dockerfile
+++ b/dockers/base-cuda/Dockerfile
@@ -118,8 +118,7 @@ RUN \

 RUN \
     # install DeepSpeed
-    # TODO(@SeanNaren): CI failing with `>=0.3.15` - skipping to unblock
-    pip install deepspeed==0.3.14
+    pip install deepspeed==0.4.0

 RUN \
     # Show what we have
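Note: the `DeepSpeedPlugin` diff below deprecates the `cpu_offload*` flags in favor of new offload arguments. A minimal migration sketch (the mapping matches the deprecation shim added in `__init__` below):

```python
from pytorch_lightning.plugins import DeepSpeedPlugin

# Deprecated since v1.4 (removed in v1.5); now emits a DeprecationWarning:
plugin = DeepSpeedPlugin(
    stage=3, cpu_offload=True, cpu_offload_params=True, cpu_offload_use_pin_memory=True
)

# Equivalent call with the new argument names:
plugin = DeepSpeedPlugin(
    stage=3, offload_optimizer=True, offload_parameters=True, pin_memory=True
)
```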
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index dc688de65cd34..8f613081cdfe2 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -15,9 +15,9 @@
 import json
 import logging
 import os
+import warnings
 from collections import OrderedDict
 from pathlib import Path
-from types import SimpleNamespace
 from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union

 import torch
@@ -78,9 +78,23 @@ def __init__(
         self,
         zero_optimization: bool = True,
         stage: int = 2,
-        cpu_offload: bool = False,
-        cpu_offload_params: bool = False,
-        cpu_offload_use_pin_memory: bool = False,
+        remote_device: str = 'cpu',
+        offload_optimizer: bool = False,
+        offload_parameters: bool = False,
+        offload_params_device: str = 'cpu',
+        nvme_path: str = '/local_nvme',
+        params_buffer_count: int = 5,
+        params_buffer_size: int = 1e8,
+        max_in_cpu: int = 1e9,
+        offload_optimizer_device: str = 'cpu',
+        optimizer_buffer_count: int = 4,
+        block_size: int = 1048576,
+        queue_depth: int = 8,
+        single_submit: bool = False,
+        overlap_events: bool = True,
+        thread_count: int = 1,
+        pin_memory: bool = False,
+        sub_group_size: int = 1e12,
         contiguous_gradients: bool = True,
         overlap_comm: bool = True,
         allgather_partitions: bool = True,
@@ -104,11 +118,14 @@ def __init__(
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
         save_full_weights: bool = True,
+        cpu_offload: bool = False,
+        cpu_offload_params: bool = False,
+        cpu_offload_use_pin_memory: bool = False,
     ) -> None:
         """
         Provides capabilities to run training using the DeepSpeed library,
         with training optimizations for large billion parameter models.
-        `For more information: https://www.deepspeed.ai/`.
+        `For more information: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed`.

         .. warning:: ``DeepSpeedPlugin`` is in beta and subject to change.
@@ -118,36 +135,81 @@ def __init__(

         Arguments:

-            zero_optimization: Enable ZeRO optimization. This is only compatible with precision=16. (default: True)
+            zero_optimization: Enable ZeRO optimization. This is only compatible with precision=16.

             stage: Different stages of the ZeRO Optimizer. 0 is disabled,
-                1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning (default: 2)
+                1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning,
+                3 is optimizer+gradient_parameter partitioning using the infinity engine.
+
+            remote_device: Device to instantiate the model on initially (``cpu`` or ``nvme``).
+
+            offload_optimizer: Enable offloading optimizer memory and computation to CPU or NVMe
+                based on ``offload_optimizer_device``.
+
+            offload_parameters: When using ZeRO Stage 3, enable offloading parameter memory and computation
+                to CPU or NVMe based on ``offload_params_device``.
+
+            offload_params_device: When offloading parameters, choose the device to offload to, ``cpu`` or ``nvme``.
+
+            offload_optimizer_device: When offloading optimizer state, choose the device to offload to,
+                ``cpu`` or ``nvme``.
+
+            params_buffer_count: Number of buffers in buffer pool for
+                parameter offloading when ``offload_params_device`` is ``nvme``.
+
+            params_buffer_size: Size of buffers in buffer pool for parameter offloading
+                when ``offload_params_device`` is ``nvme``.

-            cpu_offload: Enable offloading optimizer memory and computation to CPU
+            max_in_cpu: Number of parameter elements to maintain in CPU memory when offloading to NVMe is enabled.

-            cpu_offload_params: When using ZeRO stage 3, offload parameters to CPU
+            nvme_path: Filesystem path for NVMe device for optimizer/parameter state offloading.

-            cpu_offload_use_pin_memory: When using ZeRO stage 3, pin memory on CPU
+            optimizer_buffer_count: Number of buffers in buffer pool for optimizer state offloading
+                when ``offload_optimizer_device`` is set to ``nvme``.
+                This should be at least the number of states maintained per parameter by the optimizer.
+                For example, the Adam optimizer has 4 states (parameter, gradient, momentum, and variance).
+
+            block_size: When using NVMe Offloading, the I/O block size in bytes.
+
+            queue_depth: When using NVMe Offloading, the I/O queue depth.
+
+            single_submit: When using NVMe Offloading,
+                submit requests to storage device as multiple individual requests,
+                as opposed to one block of requests.
+
+            overlap_events: When using NVMe Offloading,
+                submit requests to storage device in an overlapped fashion
+                without waiting for completion of earlier requests.
+
+            thread_count: When using NVMe Offloading,
+                intra-request parallelism for each read/write submitted by a user thread.
+
+            pin_memory: When using ZeRO stage 3, pin optimizer state memory on CPU.
+                This could boost throughput at the cost of extra memory overhead.
+
+            sub_group_size: When using ZeRO stage 3, defines the number of parameters
+                within a sub group to offload at a time.
+                Smaller numbers require more communication, but improve memory efficiency.

             contiguous_gradients: Copies gradients to a continuous buffer as they are produced.
-                Avoids memory fragmentation during backwards. Useful when training large models. (default: True)
+                Avoids memory fragmentation during backwards. Useful when training large models.

             overlap_comm: Overlap the reduction (synchronization) of gradients with the backwards computation.
-                This is a speed optimization when training across multiple GPUs/machines. (default: True)
+                This is a speed optimization when training across multiple GPUs/machines.

             allgather_partitions: All gather updated parameters at the end of training step,
-                instead of using a series of broadcast collectives (default: True)
+                instead of using a series of broadcast collectives.

-            reduce_scatter: Use reduce/scatter instead of allreduce to average gradients (default:True)
+            reduce_scatter: Use reduce/scatter instead of allreduce to average gradients.

             allgather_bucket_size: Number of elements to allgather at once.
-                Used to limit the memory required for larger model sizes, with a tradeoff with speed. (default: 2e8)
+                Used to limit the memory required for larger model sizes, with a tradeoff with speed.

             reduce_bucket_size: Number of elements to reduce at once.
-                Used to limit the memory required for larger model sizes, with a tradeoff with speed (default: 2e8)
+                Used to limit the memory required for larger model sizes, with a tradeoff with speed.

             zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
-                DeepSpeed supported optimizer when using ZeRO (default: True)
+                DeepSpeed supported optimizer when using ZeRO.

             logging_batch_size_per_gpu: Config used in DeepSpeed to calculate verbose timing for logging
                 on a per sample per second basis (only displayed if logging=logging.INFO).
@@ -158,45 +220,55 @@ def __init__(
             config: Pass in a deepspeed formatted config dict,
                 or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json.
-                All defaults will be ignored if a config is passed in. (Default: ``None``)
+                All defaults will be ignored if a config is passed in.

-            logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``)
+            logging_level: Set logging level for deepspeed.

             loss_scale: Loss scaling value for FP16 training.
-                0.0 results in dynamic loss scaling, otherwise static (Default: 0)
+                0.0 results in dynamic loss scaling, otherwise static.

             initial_scale_power: Power of the initial dynamic loss scale value. Loss scale is computed
-                by ``2^initial_scale_power`` (Default: 32)
+                by ``2^initial_scale_power``.

-            loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value (Default: 1000)
+            loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value.

-            hysteresis: FP16 Delay shift in Dynamic Loss scaling (Default: 2)
+            hysteresis: FP16 delay shift in dynamic loss scaling.

-            min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1000)
+            min_loss_scale: The minimum FP16 dynamic loss scaling value.

-            partition_activations: Enables partition activation when used with ZeRO stage 3.
+            partition_activations: Enables partition activation when used with ZeRO stage 3 and model parallelism.
                 Still requires you to wrap your forward functions in deepspeed.checkpointing.checkpoint.
                 See `deepspeed tutorial
-                <https://www.deepspeed.ai/tutorials/megatron/#deepspeed-activation-checkpoints-optional>`_
+                <https://www.deepspeed.ai/tutorials/megatron/#deepspeed-activation-checkpoints-optional>`_.

-            cpu_checkpointing: Offloads partitioned activations to CPU if ``partition_activations`` is enabled
+            cpu_checkpointing: Offloads partitioned activations to CPU if ``partition_activations`` is enabled.

             contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory.
-                Not supported by all models
+                Not supported by all models.

             synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.

             save_full_weights: Gathers weights across all processes before saving to disk when using ZeRO Stage 3.
                 This allows a single weight file to contain the entire model,
                 rather than individual sharded weight files.
-                Disable to save sharded states individually. (Default: True)
-
+                Disable to save sharded states individually.
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
                 "To use the DeepSpeed plugin, you must have DeepSpeed installed."
                 " pip install deepspeed"
             )
+
+        if cpu_offload or cpu_offload_params or cpu_offload_use_pin_memory:
+            warnings.warn(
+                "The usage of `cpu_offload`, `cpu_offload_params`, and `cpu_offload_use_pin_memory` "
+                "is deprecated since v1.4 and will be removed in v1.5."
+                " From now on use `offload_optimizer`, `offload_parameters` and `pin_memory`.", DeprecationWarning
+            )
+            offload_optimizer = cpu_offload
+            offload_parameters = cpu_offload_params
+            pin_memory = cpu_offload_use_pin_memory
+
         super().__init__(
             parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment
         )
@@ -207,24 +279,38 @@ def __init__(
             zero_optimization,
             zero_allow_untested_optimizer,
             logging_batch_size_per_gpu,
+            offload_optimizer=offload_optimizer,
+            offload_parameters=offload_parameters,
+            nvme_path=nvme_path,
+            offload_params_device=offload_params_device,
+            params_buffer_count=params_buffer_count,
+            params_buffer_size=params_buffer_size,
+            max_in_cpu=max_in_cpu,
+            pin_memory=pin_memory,
+            offload_optimizer_device=offload_optimizer_device,
+            optimizer_buffer_count=optimizer_buffer_count,
+            block_size=block_size,
+            queue_depth=queue_depth,
+            single_submit=single_submit,
+            overlap_events=overlap_events,
+            thread_count=thread_count,
             partition_activations=partition_activations,
             cpu_checkpointing=cpu_checkpointing,
             contiguous_memory_optimization=contiguous_memory_optimization,
             synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
             stage=stage,
-            cpu_offload=cpu_offload,
-            cpu_offload_params=cpu_offload_params,
-            cpu_offload_use_pin_memory=cpu_offload_use_pin_memory,
             contiguous_gradients=contiguous_gradients,
             overlap_comm=overlap_comm,
             allgather_partitions=allgather_partitions,
             reduce_scatter=reduce_scatter,
             allgather_bucket_size=allgather_bucket_size,
             reduce_bucket_size=reduce_bucket_size,
+            sub_group_size=sub_group_size,
         )
         self._config_initialized = False
         deepspeed.utils.logging.logger.setLevel(logging_level)

+        self.remote_device = remote_device
         self.save_full_weights = save_full_weights

         # default FP16 parameters.
@@ -247,22 +333,30 @@ def _load_config(self, config):
                 config = json.load(f)
         return config

+    def setup_distributed(self):
+        super().setup_distributed()
+        if not self._config_initialized:
+            self._format_config()
+            self._config_initialized = True
+        if self.on_gpu:
+            torch.cuda.set_device(self.root_device)
+
     def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()

     def init_deepspeed(self):
-        if not self._config_initialized:
-            self._format_config()
-            self._config_initialized = True
-
         self._handle_gradient_accumulation_steps()

         precision = self.lightning_module.trainer.accelerator.precision
         model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

-        if self.on_gpu:
-            torch.cuda.set_device(self.root_device)
+        if self.zero_stage_3:
+            # Ensure the entire model has been moved to the appropriate device
+            dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
+            deepspeed.zero.Init(
+                module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
+            )

         if self.lightning_module.trainer and self.lightning_module.trainer.training:
             self._initialize_deepspeed_train(model)
@@ -287,6 +381,7 @@ def zero_stage_3(self) -> bool:

     def _initialize_deepspeed_train(self, model):
         optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
+
         if "optimizer" not in self.config:
             rank_zero_info(
                 "You have not specified an optimizer or scheduler within the DeepSpeed config."
@@ -295,12 +390,12 @@ def _initialize_deepspeed_train(self, model):
             optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer()
         model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
         model, optimizer, _, lr_scheduler = deepspeed.initialize(
-            args=SimpleNamespace(local_rank=self.local_rank),
+            config=self.config,
             model=model,
             model_parameters=model_parameters,
             optimizer=optimizer,
             lr_scheduler=lightning_scheduler,
-            config_params=self.config,
+            dist_init_required=False
         )

         self._set_deepspeed_activation_checkpointing()
@@ -312,13 +407,21 @@ def _initialize_deepspeed_train(self, model):
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
-            model_parallel_context = deepspeed.zero.Init(remote_device="cpu", pin_memory=True)
+            assert self._config_initialized
+            dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
+            model_parallel_context = deepspeed.zero.Init(
+                remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
+            )
         else:
             model_parallel_context = super().model_sharded_context()

         with model_parallel_context:
             yield

+    @property
+    def precision(self) -> Union[str, int]:
+        return self.lightning_module.trainer.precision
+
     def _set_deepspeed_activation_checkpointing(self):
         if self.config.get('activation_checkpointing'):
             checkpoint_config = self.config['activation_checkpointing']
@@ -353,12 +456,12 @@ def _initialize_deepspeed_inference(self, model):
         # Remove all module hooks before initializing new model
         remove_module_hooks(model)
         model, _, _, _ = deepspeed.initialize(
-            args=SimpleNamespace(local_rank=self.local_rank),
+            config=inference_config,
             model=model,
             optimizer=optimizer,
             lr_scheduler=lightning_scheduler,
-            config_params=inference_config,
             model_parameters=[],
+            dist_init_required=False
         )
         self.model = model

@@ -469,6 +572,21 @@ def _create_default_config(
         cpu_checkpointing: bool,
         contiguous_memory_optimization: bool,
         synchronize_checkpoint_boundary: bool,
+        offload_optimizer: bool,
+        offload_parameters: bool,
+        nvme_path: str,
+        offload_params_device: str,
+        params_buffer_count: int,
+        params_buffer_size: int,
+        max_in_cpu: int,
+        offload_optimizer_device: str,
+        optimizer_buffer_count: int,
+        pin_memory: bool,
+        block_size: int,
+        queue_depth: int,
+        single_submit: bool,
+        overlap_events: bool,
+        thread_count: int,
         **zero_kwargs,
     ) -> Dict:
         cfg = {
@@ -477,12 +595,37 @@ def _create_default_config(
             "cpu_checkpointing": cpu_checkpointing,
             "contiguous_memory_optimization": contiguous_memory_optimization,
             "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary
-            }
+            },
+            "aio": {
+                "block_size": block_size,
+                "queue_depth": queue_depth,
+                "single_submit": single_submit,
+                "overlap_events": overlap_events,
+                "thread_count": thread_count
+            },
         }
         if zero_optimization:
+            zero_config = zero_kwargs
+
+            if offload_optimizer:
+                zero_config["offload_optimizer"] = {
+                    'device': offload_optimizer_device,
+                    'nvme_path': nvme_path,
+                    'buffer_count': optimizer_buffer_count,
+                    'pin_memory': pin_memory
+                }
+            if offload_parameters:
+                zero_config['offload_param'] = {
+                    'device': offload_params_device,
+                    'nvme_path': nvme_path,
+                    'buffer_count': params_buffer_count,
+                    'buffer_size': params_buffer_size,
+                    'max_in_cpu': max_in_cpu,
+                    'pin_memory': pin_memory
+                }
             cfg = {
                 "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
-                "zero_optimization": zero_kwargs,
+                "zero_optimization": zero_config,
                 **cfg
             }
         if logging_batch_size_per_gpu != 'auto':
@@ -570,7 +713,7 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             cls,
             description="DeepSpeed ZeRO Stage 2 and CPU Offload",
             stage=2,
-            cpu_offload=True
+            offload_optimizer=True
         )
         plugin_registry.register("deepspeed_stage_3", cls, description="DeepSpeed ZeRO Stage 3", stage=3)
         plugin_registry.register(
@@ -578,5 +721,17 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             cls,
             description="DeepSpeed ZeRO Stage 3 and CPU Offload",
             stage=3,
-            cpu_offload=True
+            offload_optimizer=True,
+            offload_parameters=True,
+        )
+        plugin_registry.register(
+            "deepspeed_stage_3_offload_nvme",
+            cls,
+            description="DeepSpeed ZeRO Stage 3 and NVMe Offload",
+            stage=3,
+            offload_optimizer=True,
+            offload_parameters=True,
+            remote_device='nvme',
+            offload_params_device='nvme',
+            offload_optimizer_device='nvme'
         )
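For reference, a stage-3 run with both offload flags enabled and the constructor defaults shown above translates into a DeepSpeed config along these lines (a hand-assembled sketch of what `_create_default_config` produces, not output copied from the code):

```python
# Sketch: generated DeepSpeed config for DeepSpeedPlugin(stage=3,
# offload_optimizer=True, offload_parameters=True) with default arguments.
config = {
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",  # offload_optimizer_device
            "nvme_path": "/local_nvme",
            "buffer_count": 4,  # optimizer_buffer_count
            "pin_memory": False,
        },
        "offload_param": {
            "device": "cpu",  # offload_params_device
            "nvme_path": "/local_nvme",
            "buffer_count": 5,  # params_buffer_count
            "buffer_size": 1e8,  # params_buffer_size
            "max_in_cpu": 1e9,
            "pin_memory": False,
        },
        # ...plus the remaining **zero_kwargs (contiguous_gradients, overlap_comm, etc.)
    },
    "aio": {
        "block_size": 1048576,
        "queue_depth": 8,
        "single_submit": False,
        "overlap_events": True,
        "thread_count": 1,
    },
    "activation_checkpointing": {
        "partition_activations": False,
        "cpu_checkpointing": False,
        "contiguous_memory_optimization": False,
        "synchronize_checkpoint_boundary": False,
    },
}
```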
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index d6c9b6d8f8f31..513650a945e95 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -25,12 +25,14 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.decorators import auto_move_data
 from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.plugins import DeepSpeedPlugin
 from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.imports import _compare_version
 from tests.deprecated_api import no_deprecated_call
 from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers.runif import RunIf
 from tests.helpers.utils import no_warning_call

@@ -374,3 +376,15 @@ def test_v1_5_0_datamodule_setter():
 def test_v1_5_0_trainer_tbptt_steps(tmpdir):
     with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
         _ = Trainer(truncated_bptt_steps=1)
+
+
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize(
+    "params", [dict(cpu_offload=True),
+               dict(cpu_offload_params=True),
+               dict(cpu_offload_use_pin_memory=True)]
+)
+def test_v1_5_0_deepspeed_cpu_offload(tmpdir, params):
+
+    with pytest.deprecated_call(match="is deprecated since v1.4 and will be removed in v1.5"):
+        DeepSpeedPlugin(**params)

diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py
index 8ccba40013517..12845ed47d901 100644
--- a/tests/plugins/test_plugins_registry.py
+++ b/tests/plugins/test_plugins_registry.py
@@ -54,14 +54,15 @@ def __init__(self, param1, param2):
         }),
         ("deepspeed_stage_2_offload", {
             "stage": 2,
-            "cpu_offload": True
+            "offload_optimizer": True
         }),
         ("deepspeed_stage_3", {
             "stage": 3
         }),
         ("deepspeed_stage_3_offload", {
             "stage": 3,
-            "cpu_offload": True
+            "offload_parameters": True,
+            "offload_optimizer": True
         }),
     ],
 )
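Finally, a usage sketch for the new arguments and the `deepspeed_stage_3_offload_nvme` registry entry added above (GPU count, precision, and the NVMe mount path are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Explicit construction using the new offload arguments.
trainer = Trainer(
    gpus=4,
    precision=16,
    plugins=DeepSpeedPlugin(
        stage=3,
        offload_optimizer=True,
        offload_parameters=True,
        remote_device="nvme",
        offload_params_device="nvme",
        offload_optimizer_device="nvme",
        nvme_path="/local_nvme",
    ),
)

# Shorthand via the plugin registry entry registered in this diff.
trainer = Trainer(gpus=4, precision=16, plugins="deepspeed_stage_3_offload_nvme")
```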