Skip to content

Commit f77860c

Browse files
stancldfour4fish
authored andcommitted
Fix mypy for utilities.xla_device (Lightning-AI#8755)
* Fix mypy for utilities.xla_device * Add explicit type hint for _XLA_AVAILABLE in utilities.imports
1 parent c5aff5b commit f77860c

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,6 @@ module = [
7171
"pytorch_lightning.utilities.distributed",
7272
"pytorch_lightning.utilities.memory",
7373
"pytorch_lightning.utilities.parsing",
74+
"pytorch_lightning.utilities.xla_device",
7475
]
7576
ignore_errors = "False"

pytorch_lightning/utilities/imports.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _compare_version(package: str, op, version) -> bool:
8989
_TORCHVISION_AVAILABLE = _module_available("torchvision")
9090
_TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0")
9191
_TORCHMETRICS_GREATER_EQUAL_0_3 = _compare_version("torchmetrics", operator.ge, "0.3.0")
92-
_XLA_AVAILABLE = _module_available("torch_xla")
92+
_XLA_AVAILABLE: bool = _module_available("torch_xla")
9393

9494
from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402
9595

pytorch_lightning/utilities/xla_device.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import queue as q
1717
import traceback
1818
from multiprocessing import Process, Queue
19+
from typing import Any, Callable, Union
1920

2021
from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
2122

@@ -26,7 +27,7 @@
2627
TPU_CHECK_TIMEOUT = 60
2728

2829

29-
def inner_f(queue, func, *args, **kwargs): # pragma: no cover
30+
def inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover
3031
try:
3132
queue.put(func(*args, **kwargs))
3233
# todo: specify the possible exception
@@ -35,10 +36,10 @@ def inner_f(queue, func, *args, **kwargs): # pragma: no cover
3536
queue.put(None)
3637

3738

38-
def pl_multi_process(func):
39+
def pl_multi_process(func: Callable) -> Callable:
3940
@functools.wraps(func)
40-
def wrapper(*args, **kwargs):
41-
queue = Queue()
41+
def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
42+
queue: Queue = Queue()
4243
proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
4344
proc.start()
4445
proc.join(TPU_CHECK_TIMEOUT)

0 commit comments

Comments
 (0)