Commit d319270
materialize
1 parent 0929ed7

File tree
  • src/lightning/pytorch/strategies/fsdp.py

1 file changed: 7 additions, 2 deletions
src/lightning/pytorch/strategies/fsdp.py (7 additions, 2 deletions)
@@ -53,7 +53,7 @@
     _TORCH_GREATER_EQUAL_2_0,
 )
 from lightning.fabric.utilities.init import _EmptyInit
-from lightning.fabric.utilities.load import _lazy_load
+from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH, ProcessGroup, ReduceOp
@@ -574,7 +574,12 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
 
         if _is_full_checkpoint(path):
             checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu")
-            _load_raw_module_state(checkpoint["state_dict"], module=self.model, world_size=self.world_size)
+            _load_raw_module_state(checkpoint.pop("state_dict"), module=self.model, world_size=self.world_size)
+
+            if _TORCH_GREATER_EQUAL_2_0:
+                # Materialize lazy tensors if there are any left in the checkpoint
+                # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
+                checkpoint = _materialize_tensors(checkpoint)
 
             from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
             from torch.distributed.fsdp import OptimStateKeyType
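For context: after "state_dict" is popped out and loaded into the module, `_materialize_tensors` resolves any lazy tensors still left in the remaining checkpoint entries (notably the optimizer states) into real tensors. Below is a minimal sketch of a helper in that spirit, assuming lazy tensors are represented by a wrapper exposing a materialize method; `_LazyTensorStub` and `materialize()` are hypothetical names for illustration, not Fabric's actual internals:

from typing import Any


class _LazyTensorStub:
    """Hypothetical lazy-tensor wrapper: holds a zero-arg callable that
    loads the real tensor from storage on demand."""

    def __init__(self, loader):
        self._loader = loader

    def materialize(self):
        return self._loader()


def materialize_tensors(obj: Any) -> Any:
    """Recursively replace lazy-tensor stubs with the tensors they wrap,
    preserving the container structure of the checkpoint dict."""
    if isinstance(obj, _LazyTensorStub):
        return obj.materialize()
    if isinstance(obj, dict):
        return {key: materialize_tensors(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(materialize_tensors(value) for value in obj)
    return obj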

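The diff comment about "deepcopy pickle issues" refers to the fact that `torch.optim.Optimizer.load_state_dict` deep-copies the incoming state dict, and `copy.deepcopy` falls back on pickle semantics for objects it does not know how to copy. A self-contained demonstration of that failure mode, using an open file handle as a stand-in for the unpicklable resource a lazy tensor keeps alive (an assumption for illustration; the exact objects Fabric's lazy tensors hold may differ):

import copy
import tempfile


class _HandleBackedTensor:
    """Hypothetical lazy tensor that keeps an open file handle to the
    checkpoint; open handles cannot be pickled."""

    def __init__(self, fh):
        self._fh = fh


with tempfile.TemporaryFile() as fh:
    # Optimizer.load_state_dict deepcopies the incoming state, so a lazy
    # tensor left in the optimizer state triggers this same TypeError:
    state = {"optimizer_states": [{"exp_avg": _HandleBackedTensor(fh)}]}
    try:
        copy.deepcopy(state)
    except TypeError as err:
        print(f"deepcopy failed: {err}")  # e.g. cannot pickle '_io.BufferedRandom' object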