
Commit 839813e

Authored by lezwon, with Borda and awaelchli
timeout for tpu check (#4340)
* timeout for tpu check
* added tests
* updated CHANGELOG.md
* fixed windows tests
* Update pytorch_lightning/utilities/xla_device_utils.py
* requested changes

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 6211fd4 · commit 839813e

File tree: 3 files changed (+29, −9 lines changed)


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
 
+- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
+
 ### Changed
 
 - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))

pytorch_lightning/utilities/xla_device_utils.py

Lines changed: 10 additions & 6 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 import functools
 import importlib
+import queue as q
 from multiprocessing import Process, Queue
 
 import torch
@@ -24,10 +25,10 @@
 xm = None
 
 
-def inner_f(queue, func, **kwargs):  # pragma: no cover
+def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
     try:
-        queue.put(func(**kwargs))
-    except Exception as _e:
+        queue.put(func(*args, **kwargs))
+    except Exception:
         import traceback
 
         traceback.print_exc()
@@ -38,10 +39,13 @@ def pl_multi_process(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         queue = Queue()
-        proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs)
+        proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
         proc.start()
-        proc.join()
-        return queue.get()
+        proc.join(10)
+        try:
+            return queue.get_nowait()
+        except q.Empty:
+            return False
 
     return wrapper
 
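The core of the fix is the last hunk: previously `proc.join()` blocked the parent until the child exited, so a TPU probe that hung also hung the whole process. With `proc.join(10)` the parent waits at most 10 seconds, and `queue.get_nowait()` raises `queue.Empty` when the child never delivered a result, which the wrapper maps to `False`. Below is a minimal standalone sketch of the same pattern; `run_with_timeout` and `_child` are hypothetical names for illustration, not part of the Lightning API:

```python
import queue as q
from multiprocessing import Process, Queue


def _child(out, func, *args, **kwargs):
    # Runs in the subprocess: compute the result and ship it back via the queue.
    out.put(func(*args, **kwargs))


def run_with_timeout(func, *args, timeout=10.0, **kwargs):
    out = Queue()
    proc = Process(target=_child, args=(out, func, *args), kwargs=kwargs)
    proc.start()
    proc.join(timeout)  # wait at most `timeout` seconds instead of forever
    try:
        return out.get_nowait()
    except q.Empty:
        # Child produced nothing in time -> report failure instead of hanging.
        return False


if __name__ == "__main__":
    import time

    print(run_with_timeout(time.sleep, 30))  # False after ~10 s; the sleep never finishes
```

Note that, as in the patch, a timed-out child is not terminated — `join(10)` merely stops waiting for it — so a stricter variant might call `proc.terminate()` before returning `False`.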

tests/utilities/test_xla_device_utils.py

Lines changed: 17 additions & 3 deletions
@@ -11,13 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import time
+
 import pytest
 
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu
+import pytorch_lightning.utilities.xla_device_utils as xla_utils
 from tests.base.develop_utils import pl_multi_process_test
 
 try:
     import torch_xla.core.xla_model as xm
+
     XLA_AVAILABLE = True
 except ImportError as e:
     XLA_AVAILABLE = False
@@ -26,13 +29,13 @@
 @pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
     """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xdu.tpu_device_exists() is None
+    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
 
 
 @pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xdu.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
 
 
 @pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
@@ -42,3 +45,14 @@ def test_xla_device_is_a_tpu():
     device = xm.xla_device()
     device_type = xm.xla_device_hw(device)
     return device_type == "TPU"
+
+
+def test_result_returns_within_10_seconds():
+    """Check that pl_multi_process returns within 10 seconds"""
+
+    start = time.time()
+    result = xla_utils.pl_multi_process(time.sleep)(25)
+    end = time.time()
+    elapsed_time = int(end - start)
+
+    assert elapsed_time <= 10
+    assert result is False
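The new test exercises the timeout path end to end: `pl_multi_process(time.sleep)(25)` forks a child that sleeps for 25 seconds, so the parent gives up after the 10-second `join` window and must report `False`. A quick interactive check of the same behaviour (assuming the patched module is importable) could look like:

```python
import time

from pytorch_lightning.utilities.xla_device_utils import pl_multi_process

guarded_sleep = pl_multi_process(time.sleep)
print(guarded_sleep(25))  # child outlives the 10 s join window -> False
print(guarded_sleep(1))   # child finishes in time -> None (time.sleep's return value)
```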
