diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml
index 0a2465b85c484..68206ad8e980c 100644
--- a/.azure-pipelines/gpu-tests.yml
+++ b/.azure-pipelines/gpu-tests.yml
@@ -52,8 +52,8 @@ jobs:
   - bash: |
       python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
-      pip install fairscale>=0.4.5
-      pip install deepspeed>=0.6.0
+      pip install "fairscale>=0.4.5"
+      pip install "deepspeed<0.6.0" # https://github.com/microsoft/DeepSpeed/issues/1878
       CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
       pip install "bagua-cuda$CUDA_VERSION_MM>=0.9.0"
       pip install . --requirement requirements/devel.txt
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index 232c51ad636ee..569e3530201d2 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -98,6 +98,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BAGUA_AVAILABLE = _package_available("bagua")
 _DEEPSPEED_AVAILABLE = _package_available("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_5_9 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.5.9")
 _DEEPSPEED_GREATER_EQUAL_0_6 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.6.0")
 _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")
 _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py
index e2be98b970967..319289d200f4f 100644
--- a/tests/strategies/test_deepspeed_strategy.py
+++ b/tests/strategies/test_deepspeed_strategy.py
@@ -32,7 +32,11 @@
 from pytorch_lightning.strategies import DeepSpeedStrategy
 from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE, _DEEPSPEED_GREATER_EQUAL_0_6
+from pytorch_lightning.utilities.imports import (
+    _DEEPSPEED_AVAILABLE,
+    _DEEPSPEED_GREATER_EQUAL_0_5_9,
+    _DEEPSPEED_GREATER_EQUAL_0_6,
+)
 from pytorch_lightning.utilities.meta import init_meta_context
 from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.datamodules import ClassifDataModule
@@ -42,6 +46,11 @@
     import deepspeed
     from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 
+    if _DEEPSPEED_GREATER_EQUAL_0_5_9:
+        from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+    else:
+        from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer as DeepSpeedZeroOptimizer
+
 
 class ModelParallelBoringModel(BoringModel):
     def __init__(self):
@@ -296,9 +305,7 @@ def test_deepspeed_run_configure_optimizers(tmpdir):
     class TestCB(Callback):
         def on_train_start(self, trainer, pl_module) -> None:
-            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
-
-            assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(trainer.optimizers[0], DeepSpeedZeroOptimizer)
             assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD)
             assert isinstance(trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.StepLR)
             # check that the lr_scheduler config was preserved
@@ -337,9 +344,8 @@ def test_deepspeed_config(tmpdir, deepspeed_zero_config):
     class TestCB(Callback):
         def on_train_start(self, trainer, pl_module) -> None:
             from deepspeed.runtime.lr_schedules import WarmupLR
-            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
 
-            assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(trainer.optimizers[0], DeepSpeedZeroOptimizer)
             assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD)
             assert isinstance(trainer.lr_scheduler_configs[0].scheduler, WarmupLR)
             assert trainer.lr_scheduler_configs[0].interval == "step"
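
The test-file change above gates the import of the ZeRO stage-2 optimizer on the installed DeepSpeed version, since the class moved from `deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer` to `deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer` in 0.5.9. The sketch below reproduces that pattern outside Lightning's internal `_compare_version` helper; the `_deepspeed_at_least` name is hypothetical, not part of the PR, and it assumes `packaging` is installed alongside DeepSpeed.

```python
# Minimal standalone sketch of the version-gated import used in the diff above.
# Assumption: only the two deepspeed module paths shown in the PR are relied on.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def _deepspeed_at_least(minimum: str) -> bool:
    """Return True if an installed deepspeed is at least ``minimum``."""
    try:
        return Version(version("deepspeed")) >= Version(minimum)
    except PackageNotFoundError:
        return False


if _deepspeed_at_least("0.5.9"):
    # deepspeed >= 0.5.9: the stage-2 optimizer lives in stage_1_and_2
    from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
else:
    # older deepspeed: fall back to the pre-0.5.9 name, aliased for uniform use
    from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer as DeepSpeedZeroOptimizer
```

With the alias in place, test assertions such as `isinstance(trainer.optimizers[0], DeepSpeedZeroOptimizer)` work unchanged on either side of the 0.5.9 boundary.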