Skip to content

Commit cf3d995

Browse files
committed
Policy test
1 parent dd86ce0 commit cf3d995

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/tests_fabric/strategies/test_fsdp.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from lightning_utilities.core.imports import RequirementCache
2727
from torch.optim import Adam
2828

29+
import lightning
2930
from lightning.fabric import Fabric
3031
from lightning.fabric.plugins.environments import LightningEnvironment
3132
from lightning.fabric.strategies import FSDPStrategy
@@ -125,12 +126,16 @@ def test_fsdp_no_backward_sync():
125126

126127

127128
@RunIf(min_torch="1.12")
128-
@mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
129-
def test_fsdp_activation_checkpointing_support():
129+
def test_fsdp_activation_checkpointing_support(monkeypatch):
130130
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
131+
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_1_13", False)
131132
with pytest.raises(ValueError, match="activation_checkpointing` requires torch >= 1.13.0"):
132133
FSDPStrategy(activation_checkpointing=Mock())
133134

135+
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False)
136+
with pytest.raises(ValueError, match="activation_checkpointing_policy` requires torch >= 2.1.0"):
137+
FSDPStrategy(activation_checkpointing_policy=Mock())
138+
134139

135140
@RunIf(min_torch="1.13")
136141
def test_fsdp_activation_checkpointing():

0 commit comments

Comments
 (0)