Commit 0929ed7

update
1 parent f4825e5 commit 0929ed7

File tree

  • src/lightning/pytorch/strategies/fsdp.py

1 file changed (+3, −3 lines)
src/lightning/pytorch/strategies/fsdp.py

Lines changed: 3 additions & 3 deletions
@@ -53,6 +53,7 @@
     _TORCH_GREATER_EQUAL_2_0,
 )
 from lightning.fabric.utilities.init import _EmptyInit
+from lightning.fabric.utilities.load import _lazy_load
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH, ProcessGroup, ReduceOp
@@ -572,9 +573,8 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
             return metadata

         if _is_full_checkpoint(path):
-            # TODO: Support lazy-loading here (see Fabric)
-            checkpoint = torch.load(path, map_location="cpu")
-            _load_raw_module_state(checkpoint["state_dict"], world_size=self.world_size, module=self.model)
+            checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu")
+            _load_raw_module_state(checkpoint["state_dict"], module=self.model, world_size=self.world_size)

             from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
             from torch.distributed.fsdp import OptimStateKeyType
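For context: the change replaces an eager torch.load of the full checkpoint with Fabric's lazy loader when PyTorch 2.0+ is available, so tensors are only materialized when first accessed instead of all at once on CPU. Below is a minimal sketch of the same pattern outside the strategy class; the load_full_state_dict helper name is hypothetical, _lazy_load is the Fabric utility imported in the diff above, and the import path used for the _TORCH_GREATER_EQUAL_2_0 flag is an assumption, since the diff only shows it inside a parenthesized import block.

import torch

from lightning.fabric.utilities.load import _lazy_load

try:
    # Assumed location of the version flag; the module name is not
    # visible in the hunk above.
    from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
except ImportError:
    _TORCH_GREATER_EQUAL_2_0 = False


def load_full_state_dict(path):
    # Hypothetical helper mirroring the committed logic: lazy-load on
    # torch >= 2.0 so tensors materialize on first access; otherwise
    # fall back to an eager CPU load of the whole checkpoint.
    if _TORCH_GREATER_EQUAL_2_0:
        checkpoint = _lazy_load(path)
    else:
        checkpoint = torch.load(path, map_location="cpu")
    return checkpoint["state_dict"]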
