File tree Expand file tree Collapse file tree 1 file changed +7
-2
lines changed
tests/tests_fabric/strategies Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change 2626from lightning_utilities .core .imports import RequirementCache
2727from torch .optim import Adam
2828
29+ import lightning
2930from lightning .fabric import Fabric
3031from lightning .fabric .plugins .environments import LightningEnvironment
3132from lightning .fabric .strategies import FSDPStrategy
@@ -125,12 +126,16 @@ def test_fsdp_no_backward_sync():
125126
126127
127128@RunIf (min_torch = "1.12" )
128- @mock .patch ("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_13" , False )
129- def test_fsdp_activation_checkpointing_support ():
129+ def test_fsdp_activation_checkpointing_support (monkeypatch ):
130130 """Test that we error out if activation checkpointing requires a newer PyTorch version."""
131+ monkeypatch .setattr (lightning .fabric .strategies .fsdp , "_TORCH_GREATER_EQUAL_1_13" , False )
131132 with pytest .raises (ValueError , match = "activation_checkpointing` requires torch >= 1.13.0" ):
132133 FSDPStrategy (activation_checkpointing = Mock ())
133134
135+ monkeypatch .setattr (lightning .fabric .strategies .fsdp , "_TORCH_GREATER_EQUAL_2_1" , False )
136+ with pytest .raises (ValueError , match = "activation_checkpointing_policy` requires torch >= 2.1.0" ):
137+ FSDPStrategy (activation_checkpointing_policy = Mock ())
138+
134139
135140@RunIf (min_torch = "1.13" )
136141def test_fsdp_activation_checkpointing ():
You can’t perform that action at this time.
0 commit comments