Commit b212a48

Refactor XLADeviceUtils
1 parent a5b0f8b commit b212a48

30 files changed, +257 -262 lines changed

docs/source-lit/conf.py

Lines changed: 0 additions & 2 deletions
@@ -406,8 +406,6 @@ def find_source():
 from pytorch_lightning.cli import LightningCLI
 from pytorch_lightning.utilities import (
     _APEX_AVAILABLE,
-    _XLA_AVAILABLE,
-    _TPU_AVAILABLE,
     _TORCHVISION_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_10,
 )

docs/source-pytorch/conf.py

Lines changed: 0 additions & 2 deletions
@@ -394,8 +394,6 @@ def package_list_from_file(file):
 from pytorch_lightning.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
 from pytorch_lightning.utilities import (
     _APEX_AVAILABLE,
-    _XLA_AVAILABLE,
-    _TPU_AVAILABLE,
     _TORCHVISION_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_10,
 )

src/lightning_lite/accelerators/tpu.py

Lines changed: 72 additions & 3 deletions
@@ -11,18 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Union
+import functools
+import queue as q
+import traceback
+from multiprocessing import Process, Queue
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
+from lightning_utilities.core.imports import RequirementCache

 from lightning_lite.accelerators.accelerator import Accelerator
 from lightning_lite.utilities.device_parser import _check_data_type
-from lightning_lite.utilities.imports import _TPU_AVAILABLE


 class TPUAccelerator(Accelerator):
     """Accelerator for TPU devices."""

+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+        super().__init__(*args, **kwargs)
+
     def setup_device(self, device: torch.device) -> None:
         pass

@@ -47,8 +56,10 @@ def auto_device_count() -> int:
         return 8

     @staticmethod
+    @functools.lru_cache(maxsize=1)
     def is_available() -> bool:
-        return _TPU_AVAILABLE
+        # check `_XLA_AVAILABLE` again to avoid launching processes
+        return _XLA_AVAILABLE and _is_device_tpu()

     @classmethod
     def register_accelerators(cls, accelerator_registry: Dict) -> None:
@@ -59,6 +70,64 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
         )


+# define TPU availability timeout in seconds
+TPU_CHECK_TIMEOUT = 60
+
+
+def _inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
+    try:
+        queue.put(func(*args, **kwargs))
+    except Exception:
+        traceback.print_exc()
+        queue.put(None)
+
+
+def _multi_process(func: Callable) -> Callable:
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
+        queue: Queue = Queue()
+        proc = Process(target=_inner_f, args=(queue, func, *args), kwargs=kwargs)
+        proc.start()
+        proc.join(TPU_CHECK_TIMEOUT)
+        try:
+            return queue.get_nowait()
+        except q.Empty:
+            traceback.print_exc()
+            return False
+
+    return wrapper
+
+
+@_multi_process
+def _is_device_tpu() -> bool:
+    """Check if TPU devices are available. Runs XLA device check within a separate process.
+
+    Return:
+        A boolean value indicating if TPU devices are available
+    """
+    if not _XLA_AVAILABLE:
+        return False
+    import torch_xla.core.xla_model as xm
+
+    # For the TPU Pod training process, for example, if we have
+    # TPU v3-32 with 4 VMs, the world size would be 4 and as
+    # we would have to use `torch_xla.distributed.xla_dist` for
+    # multiple VMs and TPU_CONFIG won't be available, running
+    # `xm.get_xla_supported_devices("TPU")` won't be possible.
+    return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))
+
+
+_XLA_AVAILABLE = RequirementCache("torch_xla")
+
+
+def tpu_distributed() -> bool:
+    if not TPUAccelerator.is_available():
+        return False
+    import torch_xla.core.xla_model as xm
+
+    return xm.xrt_world_size() > 1
+
+
 def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
     """
     Parses the tpu_cores given in the format as accepted by the
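
The core change in this file: querying XLA for TPU devices can hang indefinitely on hosts without TPUs, so the refactored `is_available()` runs the probe in a child process via `_multi_process`, gives up after `TPU_CHECK_TIMEOUT` seconds, and caches the result with `functools.lru_cache`. Below is a minimal, Lightning-free sketch of that process-with-timeout idea; the names, the 5-second timeout, and the stand-in probe are illustrative only, not part of the commit.

import functools
import queue
import traceback
from multiprocessing import Process, Queue
from typing import Any, Callable, Union

CHECK_TIMEOUT = 5  # illustrative timeout in seconds


def _call_and_report(result_queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None:
    # Runs in the child process: report the result, or None if the call raised.
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception:
        traceback.print_exc()
        result_queue.put(None)


def run_with_timeout(func: Callable) -> Callable:
    # Execute ``func`` in a separate process; a hang or timeout degrades to False.
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
        result_queue: Queue = Queue()
        proc = Process(target=_call_and_report, args=(result_queue, func, *args), kwargs=kwargs)
        proc.start()
        proc.join(CHECK_TIMEOUT)
        try:
            return result_queue.get_nowait()
        except queue.Empty:
            return False

    return wrapper


def probe() -> bool:
    # Stand-in for the real XLA device query.
    return True


if __name__ == "__main__":
    safe_probe = run_with_timeout(probe)
    print(safe_probe())  # True, or False if the child did not answer in time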

src/lightning_lite/plugins/environments/xla_environment.py

Lines changed: 20 additions & 6 deletions
@@ -13,13 +13,10 @@
 # limitations under the License.
 import logging
 import os
+from typing import Any

+from lightning_lite.accelerators.tpu import _XLA_AVAILABLE, TPUAccelerator
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning_lite.utilities.imports import _TPU_AVAILABLE
-
-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_env_vars as xenv
-    import torch_xla.core.xla_model as xm

 log = logging.getLogger(__name__)

@@ -31,36 +28,53 @@ class XLAEnvironment(ClusterEnvironment):
     `here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
     """

+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+        super().__init__(*args, **kwargs)
+
     @property
     def creates_processes_externally(self) -> bool:
         return False

     @property
     def main_address(self) -> str:
+        import torch_xla.core.xla_env_vars as xenv
+
         return os.environ[xenv.TPU_MESH_CTLER_ADDR]

     @property
     def main_port(self) -> int:
+        import torch_xla.core.xla_env_vars as xenv
+
         return int(os.environ[xenv.TPU_MESH_CTLER_PORT])

     @staticmethod
     def detect() -> bool:
-        return _TPU_AVAILABLE
+        return TPUAccelerator.is_available()

     def world_size(self) -> int:
+        import torch_xla.core.xla_model as xm
+
         return xm.xrt_world_size()

     def set_world_size(self, size: int) -> None:
         log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

     def global_rank(self) -> int:
+        import torch_xla.core.xla_model as xm
+
         return xm.get_ordinal()

     def set_global_rank(self, rank: int) -> None:
         log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

     def local_rank(self) -> int:
+        import torch_xla.core.xla_model as xm
+
         return xm.get_local_ordinal()

     def node_rank(self) -> int:
+        import torch_xla.core.xla_env_vars as xenv
+
         return int(os.environ.get(xenv.HOST_ORDINAL, 0))
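
A pattern worth noting in this and the following files: the module-level `if _TPU_AVAILABLE: import torch_xla...` blocks are replaced by a fail-fast check in `__init__` plus imports inside the methods that need them, so importing the module itself never requires torch_xla. A minimal sketch of that guard-plus-lazy-import pattern, assuming only `lightning_utilities` is installed (the class name is illustrative, not part of the commit):

from typing import Any

from lightning_utilities.core.imports import RequirementCache

_XLA_AVAILABLE = RequirementCache("torch_xla")


class XlaWorldInfo:
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Fail fast with an informative message when torch_xla is missing.
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(*args, **kwargs)

    def world_size(self) -> int:
        # Imported inside the method, so importing this module never pulls in torch_xla.
        import torch_xla.core.xla_model as xm

        return xm.xrt_world_size()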

src/lightning_lite/plugins/io/xla_plugin.py

Lines changed: 9 additions & 4 deletions
@@ -16,21 +16,24 @@

 from lightning_utilities.core.apply_func import apply_to_collection

+from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
 from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
 from lightning_lite.utilities.cloud_io import get_filesystem
-from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
+from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning_lite.utilities.types import _PATH

-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-
 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf


 class XLACheckpointIO(TorchCheckpointIO):
     """CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies."""

+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+        super().__init__(*args, **kwargs)
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

@@ -55,4 +58,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
         # Ref: https://github.com/pytorch/xla/issues/2773
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
+        import torch_xla.core.xla_model as xm
+
         xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path)
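
Because `XLACheckpointIO.__init__` now performs the `_XLA_AVAILABLE` check, a missing torch_xla installation surfaces at construction time rather than at the first `xm.save` call. A rough usage sketch, assuming `lightning_lite` at this commit is installed:

from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO

try:
    checkpoint_io = XLACheckpointIO()
except ModuleNotFoundError as err:
    # On machines without torch_xla the constructor reports what is missing.
    print(f"XLA checkpointing unavailable: {err}")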

src/lightning_lite/strategies/launchers/xla.py

Lines changed: 5 additions & 7 deletions
@@ -19,16 +19,10 @@
 import torch.multiprocessing as mp
 from torch.multiprocessing import ProcessContext

+from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
 from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
-from lightning_lite.utilities import _TPU_AVAILABLE
 from lightning_lite.utilities.apply_func import move_data_to_device

-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-else:
-    xm, xmp = None, None
-
 if TYPE_CHECKING:
     from lightning_lite.strategies import Strategy

@@ -50,6 +44,8 @@ class _XLALauncher(_MultiProcessingLauncher):
     """

     def __init__(self, strategy: "Strategy") -> None:
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(strategy=strategy, start_method="fork")

     @property
@@ -103,6 +99,8 @@ def _save_spawn(
 ) -> Optional[ProcessContext]:
     """Wraps the :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker
     processes."""
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp

     @wraps(fn)
     def wrapped(rank: int, *_args: Any) -> None:
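
`_save_spawn` decorates the worker entry point before handing it to `xmp.spawn`, so every spawned process runs extra teardown once the function returns. An XLA-free sketch of that wrapping idea; the helper names and the teardown body are illustrative, not the commit's actual teardown logic:

import functools
from typing import Any, Callable


def with_teardown(fn: Callable, teardown: Callable[[], None]) -> Callable:
    # Wrap a worker entry point so every process runs ``teardown`` after ``fn`` returns.
    @functools.wraps(fn)
    def wrapped(rank: int, *args: Any) -> None:
        try:
            fn(rank, *args)
        finally:
            teardown()

    return wrapped


def worker(rank: int) -> None:
    print(f"worker {rank} running")


if __name__ == "__main__":
    with_teardown(worker, lambda: print("teardown"))(0)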

src/lightning_lite/strategies/xla.py

Lines changed: 18 additions & 10 deletions
@@ -13,35 +13,30 @@
 # limitations under the License.
 import io
 import os
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union

 import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.utils.data import DataLoader

 from lightning_lite.accelerators import Accelerator
+from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
 from lightning_lite.plugins.environments import XLAEnvironment
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
 from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO
 from lightning_lite.plugins.precision import Precision
 from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
 from lightning_lite.strategies.launchers.xla import _XLALauncher
 from lightning_lite.strategies.strategy import TBroadcast
-from lightning_lite.utilities import _TPU_AVAILABLE
 from lightning_lite.utilities.apply_func import apply_to_collection
 from lightning_lite.utilities.data import has_len
 from lightning_lite.utilities.distributed import ReduceOp
 from lightning_lite.utilities.rank_zero import rank_zero_only
 from lightning_lite.utilities.types import _PATH

-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_env_vars as xenv
-    import torch_xla.core.xla_model as xm
-    from torch_xla.core.xla_model import rendezvous
+if TYPE_CHECKING and _XLA_AVAILABLE:
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
-else:
-    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4


 class XLAStrategy(DDPSpawnStrategy):
@@ -71,6 +66,8 @@ def __init__(
     def root_device(self) -> torch.device:
         if not self._launched:
             raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
+        import torch_xla.core.xla_model as xm
+
         return xm.xla_device()

     @property
@@ -89,6 +86,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]:

     @property
     def is_distributed(self) -> bool:
+        import torch_xla.core.xla_env_vars as xenv
+
         # HOST_WORLD_SIZE is not set outside the xmp.spawn process
         return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1

@@ -106,8 +105,10 @@ def setup_module(self, module: Module) -> Module:
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)

-    def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
+    def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
         XLAStrategy._validate_dataloader(dataloader)
+        from torch_xla.distributed.parallel_loader import MpDeviceLoader
+
         dataloader = MpDeviceLoader(dataloader, self.root_device)
         # Mimic interface to torch.utils.data.DataLoader
         dataloader.dataset = dataloader._loader.dataset
@@ -126,6 +127,7 @@ def reduce(
                 "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
                 f" {reduce_op}"
             )
+        import torch_xla.core.xla_model as xm

         output = xm.mesh_reduce("reduce", output, sum)

@@ -136,7 +138,9 @@ def reduce(

     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
         if self.is_distributed:
-            rendezvous(name)
+            import torch_xla.core.xla_model as xm
+
+            xm.rendezvous(name)

     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if not self.is_distributed:
@@ -145,6 +149,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         torch.save(obj, buffer)
         data = bytearray(buffer.getbuffer())
         data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
+        import torch_xla.core.xla_model as xm
+
         data = xm.all_gather(data_tensor)
         buffer = io.BytesIO(data.cpu().byte().numpy())
         obj = torch.load(buffer)
@@ -162,6 +168,8 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
+        import torch_xla.core.xla_model as xm
+
         return xm.all_gather(tensor)

     def save_checkpoint(
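
The `TYPE_CHECKING and _XLA_AVAILABLE` guard together with the string annotation on `process_dataloader` keeps `MpDeviceLoader` visible to static type checkers without importing torch_xla at runtime. A minimal sketch of the same pattern, assuming torch_xla is installed and an XLA device is reachable when the function is actually called (the function name is illustrative):

from typing import TYPE_CHECKING

from torch.utils.data import DataLoader

if TYPE_CHECKING:
    from torch_xla.distributed.parallel_loader import MpDeviceLoader


def wrap_dataloader(dataloader: DataLoader) -> "MpDeviceLoader":
    # The string annotation keeps torch_xla out of module import time; the
    # real import happens only when the function runs on an XLA host.
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.parallel_loader import MpDeviceLoader

    return MpDeviceLoader(dataloader, xm.xla_device())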

src/lightning_lite/utilities/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -29,8 +29,6 @@
     _TORCH_GREATER_EQUAL_1_10,
     _TORCH_GREATER_EQUAL_1_11,
     _TORCH_GREATER_EQUAL_1_12,
-    _TPU_AVAILABLE,
-    _XLA_AVAILABLE,
 )
 from lightning_lite.utilities.rank_zero import (  # noqa: F401
     rank_zero_deprecation,
