Skip to content
56 changes: 27 additions & 29 deletions pytorch_lightning/utilities/xla_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import os
import queue as q
import traceback
from multiprocessing import Process, Queue

import torch.multiprocessing as mp

Expand All @@ -26,31 +25,33 @@
import torch_xla.distributed.xla_multiprocessing as xmp

#: define waiting time got checking TPU available in sec
TPU_CHECK_TIMEOUT = 25
TPU_CHECK_TIMEOUT = 120


def inner_f(queue, func, *args, **kwargs): # pragma: no cover
try:
queue.put(func(*args, **kwargs))
# todo: specify the possible exception
except Exception:
traceback.print_exc()
queue.put(None)
def inner_f(index, queue, func, *args): # pragma: no cover
queue.put(func(index, *args))


def pl_multi_process(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
queue = Queue()
proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
proc.start()
proc.join(TPU_CHECK_TIMEOUT)
def wrapper(*args):
smp = mp.get_context("spawn")
queue = smp.Queue()
cxt = xmp.spawn(inner_f, args=(queue, func, *args), join=False)

# errors in the subprocesses are caught and saved in the error_queues
# inside the context, but we don't bother to check them.
if not cxt.join(TPU_CHECK_TIMEOUT):
for proc in cxt.processes:
if proc.is_alive():
proc.terminate()
proc.join()

try:
return queue.get_nowait()
except q.Empty:
traceback.print_exc()
return False
return None

return wrapper

Expand All @@ -61,26 +62,23 @@ class XLADeviceUtils:
_TPU_AVAILABLE = False

@staticmethod
@pl_multi_process
def _is_device_tpu() -> bool:
def _is_device_tpu(index) -> bool:
"""
Check if device is TPU

Return:
A boolean value indicating if the xla device is a TPU device or not
"""
if not _XLA_AVAILABLE:
return False

def _fn(_: int, mp_queue):
try:
device = xm.xla_device()
mp_queue.put(device.type == 'xla')
except Exception:
mp_queue.put(False)
try:
device = xm.xla_device()
return device.type == 'xla'

smp = mp.get_context("spawn")
queue = smp.SimpleQueue()
xmp.spawn(_fn, args=(queue, ), nprocs=1)
return queue.get()
except RuntimeError:
traceback.print_exc()
return False

@staticmethod
def xla_available() -> bool:
Expand All @@ -105,7 +103,7 @@ def tpu_device_exists() -> bool:

if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:

XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
XLADeviceUtils._TPU_AVAILABLE = bool(pl_multi_process(XLADeviceUtils._is_device_tpu)())

if XLADeviceUtils._TPU_AVAILABLE:
os.environ["PL_TPU_AVAILABLE"] = '1'
Expand Down