From 3bb08b94f5d06464d7e09dea4333e0393b28e5f2 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Thu, 9 Jun 2022 10:28:12 -0700
Subject: [PATCH] Fix torch.distributed._sharded_tensor DeprecationWarning

---
 src/pytorch_lightning/core/module.py              | 7 +++++--
 tests/tests_pytorch/core/test_lightning_module.py | 5 ++++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index 4d400ef15d329..613688ee5b440 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -45,7 +45,7 @@
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.parsing import collect_init_args
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@@ -1991,7 +1991,10 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
             rank_zero_debug("Could not register sharded tensor state dict hooks")
             return
 
-        from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
+        if _TORCH_GREATER_EQUAL_1_11:
+            from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook
+        else:
+            from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
 
         self._register_state_dict_hook(state_dict_hook)
         self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 639863f4c1c72..373376191ca11 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -299,7 +299,10 @@ def assert_device(device: torch.device) -> None:
 
 @RunIf(min_torch="1.10", skip_windows=True)
 def test_sharded_tensor_state_dict(single_process_pg):
-    from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
+    if _TORCH_GREATER_EQUAL_1_11:
+        from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
+    else:
+        from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
     from torch.distributed._sharding_spec import ChunkShardingSpec
 
     class BoringModelWithShardedTensor(BoringModel):
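
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the version-gated import this change introduces. The flag computation only approximates pytorch_lightning.utilities.imports._TORCH_GREATER_EQUAL_1_11, and the helper name _sharded_tensor_hooks is hypothetical; the two import paths and the hook names come from the patch itself.

# Sketch of the fallback used in
# LightningModule._register_sharded_tensor_state_dict_hooks_if_available.
# Assumes torch and packaging are installed and that this torch build
# includes distributed support.
import torch
from packaging.version import Version

# Approximation of pytorch_lightning.utilities.imports._TORCH_GREATER_EQUAL_1_11
_TORCH_GREATER_EQUAL_1_11 = Version(torch.__version__.split("+")[0]) >= Version("1.11.0")


def _sharded_tensor_hooks():
    """Return the ShardedTensor state-dict hooks from whichever module exists.

    torch 1.11 moved the private module torch.distributed._sharded_tensor to
    torch.distributed._shard.sharded_tensor; importing the old path on newer
    torch emits the DeprecationWarning this patch silences.
    """
    if _TORCH_GREATER_EQUAL_1_11:
        from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook
    else:
        from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
    return pre_load_state_dict_hook, state_dict_hook

As in the patch, the returned hooks would then be attached via nn.Module._register_state_dict_hook / _register_load_state_dict_pre_hook so that ShardedTensor parameters are included when saving and restored when loading a state dict; Lightning only reaches these imports after the early-return guard shown at the top of the module.py hunk.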