Skip to content

Commit 036ecf2

Browse files
committed
Update tpu.py
1 parent 1203094 commit 036ecf2

File tree

1 file changed

+11
-1
lines changed
  • pytorch_lightning/accelerators

1 file changed

+11
-1
lines changed

pytorch_lightning/accelerators/tpu.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import torch
1717

1818
from pytorch_lightning.accelerators.accelerator import Accelerator
19-
from pytorch_lightning.utilities import _XLA_AVAILABLE
19+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
20+
from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
2021

2122
if _XLA_AVAILABLE:
2223
import torch_xla.core.xla_model as xm
@@ -25,6 +26,15 @@
2526
class TPUAccelerator(Accelerator):
2627
"""Accelerator for TPU devices."""
2728

29+
def setup_environment(self, root_device: torch.device) -> None:
    """Verify that the XLA runtime is importable before TPU setup continues.

    Raises:
        MisconfigurationException:
            If the TPU device is not available.
    """
    # Guard clause: nothing to do when torch_xla is importable.
    if _XLA_AVAILABLE:
        return
    raise MisconfigurationException("The TPU Accelerator requires torch_xla and a TPU device to run.")
2838
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
2939
"""Gets stats for the given TPU device.
3040

0 commit comments

Comments
 (0)