src/pytorch_lightning/core/module.py (5 additions, 2 deletions)

@@ -45,7 +45,7 @@
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.parsing import collect_init_args
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
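For reference, a flag like _TORCH_GREATER_EQUAL_1_11 is typically a module-level boolean computed once at import time. A minimal sketch of that idea, assuming packaging is available; this mirrors the intent of pytorch_lightning.utilities.imports, not its exact code:

import torch
from packaging.version import Version

# Strip any local build suffix (e.g. "1.11.0+cu113") before comparing versions.
_TORCH_GREATER_EQUAL_1_11 = Version(torch.__version__.split("+")[0]) >= Version("1.11.0")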
@@ -1991,7 +1991,10 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
             rank_zero_debug("Could not register sharded tensor state dict hooks")
             return

-        from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
+        if _TORCH_GREATER_EQUAL_1_11:
+            from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook
+        else:
+            from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

         self._register_state_dict_hook(state_dict_hook)
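The hooks themselves come from torch; the only thing this change touches is the import path, since torch.distributed._sharded_tensor moved to torch.distributed._shard.sharded_tensor in PyTorch 1.11. For anyone unfamiliar with the registration mechanism, here is a minimal sketch of the private hook API being used; the hook below is a toy stand-in for illustration, not the torch implementation:

import torch
import torch.nn as nn

def demo_state_dict_hook(module, state_dict, prefix, local_metadata):
    # Runs after module.state_dict() is assembled and may edit entries in
    # place, which is how torch's state_dict_hook injects ShardedTensor entries.
    state_dict[prefix + "demo_flag"] = torch.tensor(1)

m = nn.Linear(2, 2)
m._register_state_dict_hook(demo_state_dict_hook)
assert "demo_flag" in m.state_dict()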
tests/tests_pytorch/core/test_lightning_module.py (4 additions, 1 deletion)

@@ -299,7 +299,10 @@ def assert_device(device: torch.device) -> None:

 @RunIf(min_torch="1.10", skip_windows=True)
 def test_sharded_tensor_state_dict(single_process_pg):
-    from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
+    if _TORCH_GREATER_EQUAL_1_11:
+        from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
+    else:
+        from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
     from torch.distributed._sharding_spec import ChunkShardingSpec

     class BoringModelWithShardedTensor(BoringModel):
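For context on what the test builds with these imports, a hedged sketch of the pre-1.11 creation pattern follows; the spec values are illustrative, and the real test drives this through the single_process_pg fixture and a BoringModel subclass:

from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
from torch.distributed._sharding_spec import ChunkShardingSpec

# Shard along dim 0, with every chunk placed on rank 0's CPU
# (a single-process process group, as in the test).
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])

# A 10x20 ShardedTensor; state_dict_hook later serializes its local shards.
st = sharded_tensor_empty(spec, 10, 20)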