
Commit 86aa452

test
1 parent d319270 commit 86aa452

1 file changed: +53 -27 lines changed


tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 53 additions & 27 deletions
@@ -294,35 +294,35 @@ def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params):
     assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys()))
 
 
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
+# @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 @pytest.mark.parametrize(
     ("model", "strategy", "strategy_cfg"),
     [
         pytest.param(TestFSDPModel(), "fsdp", None, id="manually_wrapped"),
-        pytest.param(
-            TestFSDPModelAutoWrapped(),
-            FSDPStrategy,
-            {"auto_wrap_policy": custom_auto_wrap_policy},
-            marks=RunIf(max_torch="2.0.0"),
-            id="autowrap_1x",
-        ),
-        pytest.param(
-            TestFSDPModelAutoWrapped(),
-            FSDPStrategy,
-            {"auto_wrap_policy": custom_auto_wrap_policy},
-            marks=RunIf(min_torch="2.0.0"),
-            id="autowrap_2x",
-        ),
-        pytest.param(
-            TestFSDPModelAutoWrapped(),
-            FSDPStrategy,
-            {
-                "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}) if _TORCH_GREATER_EQUAL_2_1 else None,
-                "use_orig_params": True,
-            },
-            marks=RunIf(min_torch="2.1.0"),
-            id="autowrap_use_orig_params",
-        ),
+        # pytest.param(
+        #     TestFSDPModelAutoWrapped(),
+        #     FSDPStrategy,
+        #     {"auto_wrap_policy": custom_auto_wrap_policy},
+        #     marks=RunIf(max_torch="2.0.0"),
+        #     id="autowrap_1x",
+        # ),
+        # pytest.param(
+        #     TestFSDPModelAutoWrapped(),
+        #     FSDPStrategy,
+        #     {"auto_wrap_policy": custom_auto_wrap_policy},
+        #     marks=RunIf(min_torch="2.0.0"),
+        #     id="autowrap_2x",
+        # ),
+        # pytest.param(
+        #     TestFSDPModelAutoWrapped(),
+        #     FSDPStrategy,
+        #     {
+        #         "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}) if _TORCH_GREATER_EQUAL_2_1 else None,
+        #         "use_orig_params": True,
+        #     },
+        #     marks=RunIf(min_torch="2.1.0"),
+        #     id="autowrap_use_orig_params",
+        # ),
     ],
 )
 def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
@@ -589,8 +589,8 @@ def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
     trainer.strategy.barrier()
 
 
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
-@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
+# @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
+@pytest.mark.parametrize("wrap_min_params", [1024])
 def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params):
     """Test to ensure that the full state dict and optimizer states can be load when using FSDP strategy.
 
@@ -807,3 +807,29 @@ def test_save_load_sharded_state_dict(tmp_path):
     strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, state_dict_type="sharded")
     trainer = Trainer(**trainer_kwargs, strategy=strategy)
     trainer.fit(model, ckpt_path=checkpoint_path)
+
+
+@RunIf(min_torch="1.12")
+@mock.patch("lightning.pytorch.strategies.fsdp.torch.load")
+@mock.patch("lightning.pytorch.strategies.fsdp._lazy_load")
+@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
+def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_path):
+    """Test that loading a single file (full state) is lazy to reduce peak CPU memory usage."""
+    model = BoringModel()
+    checkpoint = {"state_dict": model.state_dict()}
+    lazy_load_mock.return_value = checkpoint
+
+    strategy = FSDPStrategy()
+    trainer = Trainer()
+    model.trainer = trainer
+    strategy._lightning_module = model
+    strategy.model = model
+
+    file = tmp_path / "test.ckpt"
+    file.touch()
+
+    strategy.load_checkpoint(checkpoint_path=file)
+    if _TORCH_GREATER_EQUAL_2_0:
+        lazy_load_mock.assert_called_once()
+    else:
+        torch_load_mock.assert_called_once()
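For context, the "lazy" load asserted by test_fsdp_lazy_load_full_state_dict means the full-state checkpoint is not materialized in CPU RAM all at once. A minimal standalone sketch of the same idea, using memory-mapped loading, is shown below; it assumes PyTorch >= 2.1, where torch.load accepts mmap=True, and is illustrative only, not Lightning's internal _lazy_load implementation:

import torch
from torch import nn

# Save a full (non-sharded) checkpoint in the default zipfile serialization format,
# which is required for memory-mapped loading.
model = nn.Linear(32, 2)
torch.save({"state_dict": model.state_dict()}, "full.ckpt")

# mmap=True keeps tensor storages memory-mapped on disk instead of copying
# everything into CPU RAM up front; data is paged in as tensors are accessed.
checkpoint = torch.load("full.ckpt", mmap=True, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

The test in the hunk above checks the analogous behavior inside FSDPStrategy.load_checkpoint by patching both code paths and asserting that _lazy_load is used on torch >= 2.0 and plain torch.load otherwise.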
