Skip to content

Commit ce5420a

Browse files
kaushikb11 authored and lexierule committed
Update logic for checking TPUs availability (#6767)
* Update logic for checking TPUs availability * fix flake8 * add fix
1 parent 523db14 commit ce5420a

File tree

2 files changed

+7
-16
lines changed

2 files changed

+7
-16
lines changed

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import io
22
import os
33
import re
4+
import time
45
from typing import Any, Dict, Iterable, List, Optional, Union
56

67
import torch
@@ -106,6 +107,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
106107
self.__save_end_of_training_weights(self.lightning_module)
107108
self.transfer_distrib_spawn_state_on_fit_end(results)
108109

110+
if self.global_rank == 0:
111+
time.sleep(2)
112+
109113
self.barrier("end-process")
110114

111115
def __save_end_of_training_weights(self, model: LightningModule) -> None:

pytorch_lightning/utilities/xla_device.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,10 @@
1717
import traceback
1818
from multiprocessing import Process, Queue
1919

20-
import torch.multiprocessing as mp
21-
2220
from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
2321

2422
if _XLA_AVAILABLE:
2523
import torch_xla.core.xla_model as xm
26-
import torch_xla.distributed.xla_multiprocessing as xmp
2724

2825
#: define waiting time for checking TPU availability, in seconds
2926
TPU_CHECK_TIMEOUT = 25
@@ -64,23 +61,13 @@ class XLADeviceUtils:
6461
@pl_multi_process
6562
def _is_device_tpu() -> bool:
6663
"""
67-
Check if device is TPU
64+
Check if TPU devices are available
6865
6966
Return:
70-
A boolean value indicating if the xla device is a TPU device or not
67+
A boolean value indicating if TPU devices are available
7168
"""
7269

73-
def _fn(_: int, mp_queue):
74-
try:
75-
device = xm.xla_device()
76-
mp_queue.put(device.type == 'xla')
77-
except Exception:
78-
mp_queue.put(False)
79-
80-
smp = mp.get_context("spawn")
81-
queue = smp.SimpleQueue()
82-
xmp.spawn(_fn, args=(queue, ), nprocs=1)
83-
return queue.get()
70+
return len(xm.get_xla_supported_devices("TPU")) > 0
8471

8572
@staticmethod
8673
def xla_available() -> bool:

0 commit comments

Comments
 (0)