diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index 294d3d2c5ec40..6daed6990c386 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -15,7 +15,6 @@ import os import queue as q import traceback -from multiprocessing import Process, Queue import torch.multiprocessing as mp @@ -26,31 +25,33 @@ import torch_xla.distributed.xla_multiprocessing as xmp #: define waiting time got checking TPU available in sec -TPU_CHECK_TIMEOUT = 25 +TPU_CHECK_TIMEOUT = 120 -def inner_f(queue, func, *args, **kwargs): # pragma: no cover - try: - queue.put(func(*args, **kwargs)) - # todo: specify the possible exception - except Exception: - traceback.print_exc() - queue.put(None) +def inner_f(index, queue, func, *args): # pragma: no cover + queue.put(func(index, *args)) def pl_multi_process(func): @functools.wraps(func) - def wrapper(*args, **kwargs): - queue = Queue() - proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) - proc.start() - proc.join(TPU_CHECK_TIMEOUT) + def wrapper(*args): + smp = mp.get_context("spawn") + queue = smp.Queue() + cxt = xmp.spawn(inner_f, args=(queue, func, *args), join=False) + + # errors in the subprocesses are caught and saved in the error_queues + # inside the context, but we don't bother to check them. + if not cxt.join(TPU_CHECK_TIMEOUT): + for proc in cxt.processes: + if proc.is_alive(): + proc.terminate() + proc.join() + try: return queue.get_nowait() except q.Empty: - traceback.print_exc() - return False + return None return wrapper @@ -61,26 +62,23 @@ class XLADeviceUtils: _TPU_AVAILABLE = False @staticmethod - @pl_multi_process - def _is_device_tpu() -> bool: + def _is_device_tpu(index) -> bool: """ Check if device is TPU Return: A boolean value indicating if the xla device is a TPU device or not """ + if not _XLA_AVAILABLE: + return False - def _fn(_: int, mp_queue): - try: - device = xm.xla_device() - mp_queue.put(device.type == 'xla') - except Exception: - mp_queue.put(False) + try: + device = xm.xla_device() + return device.type == 'xla' - smp = mp.get_context("spawn") - queue = smp.SimpleQueue() - xmp.spawn(_fn, args=(queue, ), nprocs=1) - return queue.get() + except RuntimeError: + traceback.print_exc() + return False @staticmethod def xla_available() -> bool: @@ -105,7 +103,7 @@ def tpu_device_exists() -> bool: if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE: - XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu() + XLADeviceUtils._TPU_AVAILABLE = bool(pl_multi_process(XLADeviceUtils._is_device_tpu)()) if XLADeviceUtils._TPU_AVAILABLE: os.environ["PL_TPU_AVAILABLE"] = '1'