 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf

-if _TORCH_GREATER_EQUAL_1_11:
+if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
     from torch.distributed.fsdp.wrap import wrap


-@RunIf(min_torch="1.11")
+@RunIf(min_torch="1.12dev")
 def test_invalid_on_cpu(tmpdir):
     """Test to ensure that Native FSDP on CPU raises a MisconfigurationException."""
     with pytest.raises(
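
The guard bumped above keeps the FSDP imports behind a torch-version check. As a rough sketch of how such a guard can be computed (an assumption for illustration, not the actual body of _TORCH_GREATER_EQUAL_1_12 in pytorch_lightning.utilities.imports):

import torch
from packaging.version import Version

# Hypothetical version guard, not Lightning's implementation. Nightly builds report
# versions like "1.12.0.dev20220518", which compare below "1.12.0" under PEP 440;
# comparing base versions (or requesting "1.12dev" from RunIf, as the markers in
# this diff do) lets pre-release builds satisfy the check.
_TORCH_GREATER_EQUAL_1_12 = Version(Version(torch.__version__).base_version) >= Version("1.12.0")

if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import wrap
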
@@ -34,7 +34,7 @@ def test_invalid_on_cpu(tmpdir):
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
 @mock.patch("torch.cuda.device_count", return_value=1)
 @mock.patch("torch.cuda.is_available", return_value=True)
-@RunIf(min_torch="1.11")
+@RunIf(min_torch="1.12dev")
 def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):
     """Test to ensure that the native AMP plugin raises a MisconfigurationException."""
     with pytest.raises(
@@ -102,7 +102,7 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert self.layer.module[2].reshard_after_forward is True


-@RunIf(min_gpus=2, skip_windows=True, standalone=True, min_torch="1.11")
+@RunIf(min_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev")
 def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
     """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run."""

@@ -119,7 +119,7 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


-@RunIf(min_gpus=1, skip_windows=True, standalone=True, min_torch="1.11")
+@RunIf(min_gpus=1, skip_windows=True, standalone=True, min_torch="1.12dev")
 def test_fully_sharded_native_strategy_checkpoint(tmpdir):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""

@@ -130,7 +130,7 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir):
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


-@RunIf(min_gpus=2, skip_windows=True, standalone=True, min_torch="1.11")
+@RunIf(min_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev")
 def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
     """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""

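
For context, a minimal sketch of what these tests exercise: running a model with the native FSDP strategy on GPUs. This assumes a PyTorch Lightning 1.6-era Trainer API and torch >= 1.12; the Trainer arguments are illustrative and not taken from this diff.

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from tests.helpers.boring_model import BoringModel  # same helper model the tests import; only available inside the repo

model = BoringModel()
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DDPFullyShardedNativeStrategy(),  # or the registered alias string "fsdp_native"
    max_epochs=1,
)
trainer.fit(model)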