Commit aa92c4c

tchaton, Your Name, kaushikb11, and carmocca authored and committed
[TPU] update is_tpu_exists utils internal logic to rely on xmp.spawn (#6719)
* update_logic
* update
* Update tests/utilities/test_xla_device_utils.py
* Update pytorch_lightning/utilities/xla_device.py
  Co-authored-by: Kaushik B <[email protected]>
* Update pytorch_lightning/utilities/xla_device.py
  Co-authored-by: Kaushik B <[email protected]>
* update test
* Update tests/utilities/test_xla_device_utils.py
* update
* Apply fix
* Docstring
* flake8
* update

Co-authored-by: Your Name <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 6afc931 commit aa92c4c
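For orientation, here is a condensed sketch of the probe pattern this commit adopts: instead of querying the XLA device in the main process, a single worker is spawned via `xmp.spawn` and reports back through a multiprocessing queue. This mirrors the logic added in `pytorch_lightning/utilities/xla_device.py` below; the helper name `probe_tpu` is illustrative, and the snippet assumes `torch_xla` is installed.

```python
# Illustrative sketch only -- mirrors the logic added in xla_device.py in this commit.
import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def probe_tpu() -> bool:  # hypothetical helper name
    """Spawn one XLA worker and report whether it sees an 'xla' device."""

    def _check(_: int, queue) -> None:
        try:
            # xm.xla_device() raises if no XLA device can be acquired
            queue.put(xm.xla_device().type == 'xla')
        except Exception:
            queue.put(False)

    queue = mp.get_context("spawn").SimpleQueue()
    xmp.spawn(_check, args=(queue, ), nprocs=1)
    return queue.get()
```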

File tree: 5 files changed, +54 −36 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
 - Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
 - Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))
+- Fixed bug where no TPUs were detected in a TPU pod env ([#6719](https://github.com/PyTorchLightning/pytorch-lightning/pull/6719))


 ## [1.2.5] - 2021-03-23

pytorch_lightning/utilities/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -14,7 +14,6 @@
 """General utilities"""

 import numpy
-
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
 from pytorch_lightning.utilities.distributed import (  # noqa: F401
     AllGatherGrad,

pytorch_lightning/utilities/xla_device.py

Lines changed: 30 additions & 24 deletions

@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import os
 import queue as q
 import traceback
 from multiprocessing import Process, Queue

-import torch
+import torch.multiprocessing as mp

 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+
 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 100
+TPU_CHECK_TIMEOUT = 25


 def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
@@ -55,34 +58,29 @@ def wrapper(*args, **kwargs):
 class XLADeviceUtils:
     """Used to detect the type of XLA device"""

-    TPU_AVAILABLE = None
-
-    @staticmethod
-    def _fetch_xla_device_type(device: torch.device) -> str:
-        """
-        Returns XLA device type
-
-        Args:
-            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
-
-        Return:
-            Returns a str of the device hardware type. i.e TPU
-        """
-        if _XLA_AVAILABLE:
-            return xm.xla_device_hw(device)
+    _TPU_AVAILABLE = False

     @staticmethod
+    @pl_multi_process
     def _is_device_tpu() -> bool:
         """
         Check if device is TPU

         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if _XLA_AVAILABLE:
-            device = xm.xla_device()
-            device_type = XLADeviceUtils._fetch_xla_device_type(device)
-            return device_type == "TPU"
+
+        def _fn(_: int, mp_queue):
+            try:
+                device = xm.xla_device()
+                mp_queue.put(device.type == 'xla')
+            except Exception:
+                mp_queue.put(False)
+
+        smp = mp.get_context("spawn")
+        queue = smp.SimpleQueue()
+        xmp.spawn(_fn, args=(queue, ), nprocs=1)
+        return queue.get()

     @staticmethod
     def xla_available() -> bool:
@@ -102,6 +100,14 @@ def tpu_device_exists() -> bool:
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
-            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
-        return XLADeviceUtils.TPU_AVAILABLE
+        if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
+            XLADeviceUtils._TPU_AVAILABLE = True
+
+        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
+
+            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+
+            if XLADeviceUtils._TPU_AVAILABLE:
+                os.environ["PL_TPU_AVAILABLE"] = '1'
+
+        return XLADeviceUtils._TPU_AVAILABLE
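After this change, a caller sees the probe run at most once per process tree: a successful detection is cached both on the class and in the `PL_TPU_AVAILABLE` environment variable, so spawned children inherit the answer. A minimal usage sketch (not part of the diff above), assuming a pytorch-lightning install that includes this patch:

```python
# Usage sketch only; assumes pytorch-lightning with this patch is installed.
import os

from pytorch_lightning.utilities.xla_device import XLADeviceUtils

if XLADeviceUtils.tpu_device_exists():
    # a successful probe is recorded, so later calls skip the xmp.spawn check
    assert os.environ["PL_TPU_AVAILABLE"] == "1"
    print("TPU detected")
else:
    print("no TPU detected")
```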

tests/plugins/test_custom_plugin.py

Lines changed: 4 additions & 0 deletions

@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
+import torch
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DDPPlugin
 from tests.helpers import BoringModel
@@ -26,6 +29,7 @@ def __init__(self, **kwargs):


 @RunIf(skip_windows=True)
+@pytest.mark.skipif(torch.cuda.is_available(), reason="RuntimeError: Tensors must be CUDA and dense")
 def test_sync_batchnorm_set(tmpdir):
     """Tests if sync_batchnorm is automatically set for custom plugin."""
     model = BoringModel()

tests/utilities/test_xla_device_utils.py

Lines changed: 19 additions & 11 deletions

@@ -17,29 +17,37 @@
 import pytest

 import pytorch_lightning.utilities.xla_device as xla_utils
-from pytorch_lightning.utilities import _TPU_AVAILABLE, _XLA_AVAILABLE
-from tests.helpers.utils import pl_multi_process_test
+from pytorch_lightning.utilities import _XLA_AVAILABLE
+from tests.helpers.runif import RunIf


 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
-    """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
+    """Check tpu_device_exists returns False when torch_xla is not available"""
+    assert not xla_utils.XLADeviceUtils.tpu_device_exists()


-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires torch_xla to be installed")
-@pl_multi_process_test
+@RunIf(tpu=True)
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists()


-@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
+def sleep_fn(sleep_time: float) -> bool:
+    time.sleep(sleep_time)
+    return True
+
+
+@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
 def test_result_returns_within_timeout_seconds():
-    """Check that pl_multi_process returns within 10 seconds"""
+    """Check that pl_multi_process returns within 3 seconds"""
+    fn = xla_utils.pl_multi_process(sleep_fn)
+
     start = time.time()
-    result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
+    result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
     end = time.time()
     elapsed_time = int(end - start)
+
     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
-    assert result is False
+    assert result
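The timeout test above relies on the contract of `pl_multi_process`: the wrapped function runs in a separate process, and the wrapper gives up and returns False if no result arrives within `TPU_CHECK_TIMEOUT` seconds. A standalone sketch of that contract, not the library's actual implementation, might look like:

```python
# Standalone sketch of the timeout contract; pl_multi_process itself lives in
# pytorch_lightning/utilities/xla_device.py and differs in detail.
import multiprocessing
import time

TIMEOUT = 3  # seconds, mirrors the patched TPU_CHECK_TIMEOUT in the test


def _target(queue, func, *args):
    # child process: run the function and report its result
    queue.put(func(*args))


def run_with_timeout(func, *args):
    """Run `func` in a child process; return False if it misses the deadline."""
    queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=_target, args=(queue, func, *args))
    proc.start()
    proc.join(TIMEOUT)
    if proc.is_alive():  # still running -> treat as a failed check
        proc.terminate()
        return False
    return queue.get()


def sleep_fn(sleep_time: float) -> bool:
    time.sleep(sleep_time)
    return True


if __name__ == "__main__":
    assert run_with_timeout(sleep_fn, TIMEOUT * 0.5) is True   # finishes in time
    assert run_with_timeout(sleep_fn, TIMEOUT * 1.5) is False  # exceeds the deadline
```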
