|
30 | 30 | from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint |
31 | 31 | from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset |
32 | 32 | from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin |
33 | | -from pytorch_lightning.plugins.precision.deepspeed import _DEEPSPEED_GREATER_EQUAL_0_6 |
34 | 33 | from pytorch_lightning.strategies import DeepSpeedStrategy |
35 | 34 | from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule |
36 | 35 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
37 | | -from pytorch_lightning.utilities.imports import _RequirementAvailable |
38 | 36 | from pytorch_lightning.utilities.meta import init_meta_context |
39 | 37 | from tests_pytorch.helpers.datamodules import ClassifDataModule |
40 | 38 | from tests_pytorch.helpers.datasets import RandomIterableDataset |
41 | 39 | from tests_pytorch.helpers.runif import RunIf |
42 | 40 |
|
43 | 41 | if _DEEPSPEED_AVAILABLE: |
44 | 42 | import deepspeed |
| 43 | + from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer |
45 | 44 | from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict |
46 | 45 |
|
47 | | - _DEEPSPEED_GREATER_EQUAL_0_5_9 = _RequirementAvailable("deepspeed>=0.5.9") |
48 | | - if _DEEPSPEED_GREATER_EQUAL_0_5_9: |
49 | | - from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer |
50 | | - else: |
51 | | - from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer as DeepSpeedZeroOptimizer |
52 | | - |
53 | 46 |
|
54 | 47 | class ModelParallelBoringModel(BoringModel): |
55 | 48 | def __init__(self): |
@@ -1294,7 +1287,6 @@ def training_step(self, *args, **kwargs): |
1294 | 1287 |
|
1295 | 1288 |
|
1296 | 1289 | @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) |
1297 | | -@pytest.mark.skipif(not _DEEPSPEED_GREATER_EQUAL_0_6, reason="requires deepspeed >= 0.6") |
1298 | 1290 | def test_deepspeed_with_bfloat16_precision(tmpdir): |
1299 | 1291 | """Test that deepspeed works with bfloat16 precision.""" |
1300 | 1292 | model = BoringModel() |
|
0 commit comments