@@ -294,35 +294,35 @@ def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params):
294294 assert all (_ex == _co for _ex , _co in zip (full_state_dict .keys (), correct_state_dict .keys ()))
295295
296296
297- @RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True , min_torch = "1.12" )
297+ # @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
298298@pytest .mark .parametrize (
299299 ("model" , "strategy" , "strategy_cfg" ),
300300 [
301301 pytest .param (TestFSDPModel (), "fsdp" , None , id = "manually_wrapped" ),
302- pytest .param (
303- TestFSDPModelAutoWrapped (),
304- FSDPStrategy ,
305- {"auto_wrap_policy" : custom_auto_wrap_policy },
306- marks = RunIf (max_torch = "2.0.0" ),
307- id = "autowrap_1x" ,
308- ),
309- pytest .param (
310- TestFSDPModelAutoWrapped (),
311- FSDPStrategy ,
312- {"auto_wrap_policy" : custom_auto_wrap_policy },
313- marks = RunIf (min_torch = "2.0.0" ),
314- id = "autowrap_2x" ,
315- ),
316- pytest .param (
317- TestFSDPModelAutoWrapped (),
318- FSDPStrategy ,
319- {
320- "auto_wrap_policy" : ModuleWrapPolicy ({nn .Linear }) if _TORCH_GREATER_EQUAL_2_1 else None ,
321- "use_orig_params" : True ,
322- },
323- marks = RunIf (min_torch = "2.1.0" ),
324- id = "autowrap_use_orig_params" ,
325- ),
302+ # pytest.param(
303+ # TestFSDPModelAutoWrapped(),
304+ # FSDPStrategy,
305+ # {"auto_wrap_policy": custom_auto_wrap_policy},
306+ # marks=RunIf(max_torch="2.0.0"),
307+ # id="autowrap_1x",
308+ # ),
309+ # pytest.param(
310+ # TestFSDPModelAutoWrapped(),
311+ # FSDPStrategy,
312+ # {"auto_wrap_policy": custom_auto_wrap_policy},
313+ # marks=RunIf(min_torch="2.0.0"),
314+ # id="autowrap_2x",
315+ # ),
316+ # pytest.param(
317+ # TestFSDPModelAutoWrapped(),
318+ # FSDPStrategy,
319+ # {
320+ # "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}) if _TORCH_GREATER_EQUAL_2_1 else None,
321+ # "use_orig_params": True,
322+ # },
323+ # marks=RunIf(min_torch="2.1.0"),
324+ # id="autowrap_use_orig_params",
325+ # ),
326326 ],
327327)
328328def test_fsdp_checkpoint_multi_gpus (tmpdir , model , strategy , strategy_cfg ):
@@ -589,8 +589,8 @@ def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
589589 trainer .strategy .barrier ()
590590
591591
592- @RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True , min_torch = "1.12" )
593- @pytest .mark .parametrize ("wrap_min_params" , [2 , 1024 , 100000000 ])
592+ # @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
593+ @pytest .mark .parametrize ("wrap_min_params" , [1024 ])
594594def test_fsdp_strategy_load_optimizer_states (tmpdir , wrap_min_params ):
595595 """Test to ensure that the full state dict and optimizer states can be load when using FSDP strategy.
596596
@@ -807,3 +807,29 @@ def test_save_load_sharded_state_dict(tmp_path):
807807 strategy = FSDPStrategy (auto_wrap_policy = {nn .Linear }, state_dict_type = "sharded" )
808808 trainer = Trainer (** trainer_kwargs , strategy = strategy )
809809 trainer .fit (model , ckpt_path = checkpoint_path )
810+
811+
812+ @RunIf (min_torch = "1.12" )
813+ @mock .patch ("lightning.pytorch.strategies.fsdp.torch.load" )
814+ @mock .patch ("lightning.pytorch.strategies.fsdp._lazy_load" )
815+ @mock .patch ("lightning.pytorch.strategies.fsdp._load_raw_module_state" )
816+ def test_fsdp_lazy_load_full_state_dict (_ , lazy_load_mock , torch_load_mock , tmp_path ):
817+ """Test that loading a single file (full state) is lazy to reduce peak CPU memory usage."""
818+ model = BoringModel ()
819+ checkpoint = {"state_dict" : model .state_dict ()}
820+ lazy_load_mock .return_value = checkpoint
821+
822+ strategy = FSDPStrategy ()
823+ trainer = Trainer ()
824+ model .trainer = trainer
825+ strategy ._lightning_module = model
826+ strategy .model = model
827+
828+ file = tmp_path / "test.ckpt"
829+ file .touch ()
830+
831+ strategy .load_checkpoint (checkpoint_path = file )
832+ if _TORCH_GREATER_EQUAL_2_0 :
833+ lazy_load_mock .assert_called_once ()
834+ else :
835+ torch_load_mock .assert_called_once ()
0 commit comments