 # limitations under the License.
 import io
 import os
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union
 
 import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.utils.data import DataLoader
 
 from lightning_lite.accelerators import Accelerator
+from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
 from lightning_lite.plugins.environments import XLAEnvironment
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
 from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO
 from lightning_lite.plugins.precision import Precision
 from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
 from lightning_lite.strategies.launchers.xla import _XLALauncher
 from lightning_lite.strategies.strategy import TBroadcast
-from lightning_lite.utilities import _TPU_AVAILABLE
 from lightning_lite.utilities.apply_func import apply_to_collection
 from lightning_lite.utilities.data import has_len
 from lightning_lite.utilities.distributed import ReduceOp
 from lightning_lite.utilities.rank_zero import rank_zero_only
 from lightning_lite.utilities.types import _PATH
 
-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_env_vars as xenv
-    import torch_xla.core.xla_model as xm
-    from torch_xla.core.xla_model import rendezvous
+if TYPE_CHECKING and _XLA_AVAILABLE:
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
-else:
-    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
 
 
 class XLAStrategy(DDPSpawnStrategy):
@@ -71,6 +66,8 @@ def __init__(
     def root_device(self) -> torch.device:
         if not self._launched:
             raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
+        import torch_xla.core.xla_model as xm
+
         return xm.xla_device()
 
     @property
@@ -89,6 +86,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]:
 
     @property
     def is_distributed(self) -> bool:
+        import torch_xla.core.xla_env_vars as xenv
+
         # HOST_WORLD_SIZE is not set outside the xmp.spawn process
         return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1
 
@@ -106,8 +105,10 @@ def setup_module(self, module: Module) -> Module:
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
+    def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
         XLAStrategy._validate_dataloader(dataloader)
+        from torch_xla.distributed.parallel_loader import MpDeviceLoader
+
         dataloader = MpDeviceLoader(dataloader, self.root_device)
         # Mimic interface to torch.utils.data.DataLoader
         dataloader.dataset = dataloader._loader.dataset
@@ -126,6 +127,7 @@ def reduce(
                 "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
                 f" {reduce_op}"
             )
+        import torch_xla.core.xla_model as xm
 
         output = xm.mesh_reduce("reduce", output, sum)
 
@@ -136,7 +138,9 @@ def reduce(
 
     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
         if self.is_distributed:
-            rendezvous(name)
+            import torch_xla.core.xla_model as xm
+
+            xm.rendezvous(name)
 
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if not self.is_distributed:
@@ -145,6 +149,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         torch.save(obj, buffer)
         data = bytearray(buffer.getbuffer())
         data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
+        import torch_xla.core.xla_model as xm
+
         data = xm.all_gather(data_tensor)
         buffer = io.BytesIO(data.cpu().byte().numpy())
         obj = torch.load(buffer)
@@ -162,6 +168,8 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
+        import torch_xla.core.xla_model as xm
+
         return xm.all_gather(tensor)
 
     def save_checkpoint(
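
The change follows one pattern throughout the diff: the module-level `torch_xla` imports (and the `[None] * 4` placeholder fallback) are removed, each method imports what it needs at call time, and the only remaining top-level import is the type-only `MpDeviceLoader` behind `if TYPE_CHECKING and _XLA_AVAILABLE`. Below is a minimal, self-contained sketch of that deferred-import pattern; the `ExampleXLAHelper` class and its methods are illustrative only and are not part of this PR.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-only import: evaluated by static type checkers, never at runtime,
    # so importing this module does not require torch_xla to be installed.
    from torch_xla.distributed.parallel_loader import MpDeviceLoader


class ExampleXLAHelper:
    """Illustrative helper showing runtime-deferred torch_xla imports."""

    def device(self):
        # Imported only when the method runs, i.e. on a host where torch_xla exists.
        import torch_xla.core.xla_model as xm

        return xm.xla_device()

    def wrap_loader(self, dataloader) -> "MpDeviceLoader":
        # The string annotation keeps the type hint usable even though
        # MpDeviceLoader is not imported at module scope at runtime.
        from torch_xla.distributed.parallel_loader import MpDeviceLoader

        return MpDeviceLoader(dataloader, self.device())
```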