19 | 19 | from pytorch_lightning.strategies import DeepSpeedStrategy |
20 | 20 | from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule |
21 | 21 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
22 | | -from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE |
| 22 | +from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE, _DEEPSPEED_GREATER_EQUAL_0_5_9 |
23 | 23 | from pytorch_lightning.utilities.meta import init_meta_context |
24 | 24 | from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset |
25 | 25 | from tests.helpers.datamodules import ClassifDataModule |
29 | 29 | import deepspeed |
30 | 30 | from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict |
31 | 31 |
| 32 | + if _DEEPSPEED_GREATER_EQUAL_0_5_9: |
| 33 | + from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer |
| 34 | + else: |
| 35 | + from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer as DeepSpeedZeroOptimizer |
| 36 | + |
32 | 37 |
33 | 38 | class ModelParallelBoringModel(BoringModel): |
34 | 39 | def __init__(self): |
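Context for the guard added in the hunk above: DeepSpeed 0.5.9 renamed `FP16_DeepSpeedZeroOptimizer` to `DeepSpeedZeroOptimizer` and moved it from `deepspeed.runtime.zero.stage2` to `deepspeed.runtime.zero.stage_1_and_2`, so the test module aliases both spellings to a single name. The real `_DEEPSPEED_GREATER_EQUAL_0_5_9` flag lives in `pytorch_lightning.utilities.imports`; the helper below is a hypothetical, minimal sketch of how such a flag can be computed, not Lightning's actual implementation:

```python
import operator
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def _compare_package_version(package: str, op, target: str) -> bool:
    # Return op(installed_version, target); False when the package is not installed.
    try:
        return op(Version(version(package)), Version(target))
    except PackageNotFoundError:
        return False


_DEEPSPEED_GREATER_EQUAL_0_5_9 = _compare_package_version("deepspeed", operator.ge, "0.5.9")
```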
@@ -280,9 +285,7 @@ def test_deepspeed_run_configure_optimizers(tmpdir): |
280 | 285 |
281 | 286 | class TestCB(Callback): |
282 | 287 | def on_train_start(self, trainer, pl_module) -> None: |
283 | | - from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer |
284 | | - |
285 | | - assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer) |
| 288 | + assert isinstance(trainer.optimizers[0], DeepSpeedZeroOptimizer) |
286 | 289 | assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD) |
287 | 290 | assert isinstance(trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.StepLR) |
288 | 291 | # check that the lr_scheduler config was preserved |
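For orientation, the asserts in this `TestCB` fire during `on_train_start` of a short DeepSpeed-enabled fit. A minimal sketch of how the surrounding test drives the callback, assuming `TestCB` from the hunk above is in scope and reusing the `BoringModel` helper (argument values are illustrative, not taken from this diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy
from tests.helpers.boring_model import BoringModel


def test_deepspeed_optimizer_wrapping_sketch(tmpdir):
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=DeepSpeedStrategy(),  # defaults to ZeRO stage 2, which wraps the optimizer
        accelerator="gpu",
        devices=1,
        precision=16,
        fast_dev_run=True,  # one batch is enough: the asserts run in on_train_start
        callbacks=TestCB(),
    )
    trainer.fit(model)
```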
@@ -319,9 +322,8 @@ def test_deepspeed_config(tmpdir, deepspeed_zero_config): |
319 | 322 | class TestCB(Callback): |
320 | 323 | def on_train_start(self, trainer, pl_module) -> None: |
321 | 324 | from deepspeed.runtime.lr_schedules import WarmupLR |
322 | | - from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer |
323 | 325 |
324 | | - assert isinstance(trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer) |
| 326 | + assert isinstance(trainer.optimizers[0], DeepSpeedZeroOptimizer) |
325 | 327 | assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD) |
326 | 328 | assert isinstance(trainer.lr_scheduler_configs[0].scheduler, WarmupLR) |
327 | 329 | assert trainer.lr_scheduler_configs[0].interval == "step" |
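The `deepspeed_zero_config` fixture referenced by this test supplies the DeepSpeed `scheduler` section that produces the `WarmupLR` instance asserted above; DeepSpeed-managed schedulers step per optimizer step, which Lightning records as an `"step"` interval. A rough sketch of such a config, with illustrative parameter values (the keys are standard DeepSpeed config fields, but the exact fixture contents are not shown in this diff):

```python
deepspeed_zero_config = {
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "SGD", "params": {"lr": 3e-5}},
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "last_batch_iteration": -1,  # start warmup from scratch
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 100,
        },
    },
}
```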