diff --git a/pyproject.toml b/pyproject.toml index 0da4b2cfe10b4..21a96f47db56e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,5 +70,6 @@ module = [ "pytorch_lightning.utilities.device_parser", "pytorch_lightning.utilities.distributed", "pytorch_lightning.utilities.parsing", + "pytorch_lightning.utilities.xla_device", ] ignore_errors = "False" diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index e8ecd23ef55f0..5e8cdd21378bb 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -89,7 +89,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCHVISION_AVAILABLE = _module_available("torchvision") _TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0") _TORCHMETRICS_GREATER_EQUAL_0_3 = _compare_version("torchmetrics", operator.ge, "0.3.0") -_XLA_AVAILABLE = _module_available("torch_xla") +_XLA_AVAILABLE: bool = _module_available("torch_xla") from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402 diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index 046c22541750c..3a9073cfa122a 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -16,6 +16,7 @@ import queue as q import traceback from multiprocessing import Process, Queue +from typing import Any, Callable, Union from pytorch_lightning.utilities.imports import _XLA_AVAILABLE @@ -26,7 +27,7 @@ TPU_CHECK_TIMEOUT = 60 -def inner_f(queue, func, *args, **kwargs): # pragma: no cover +def inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover try: queue.put(func(*args, **kwargs)) # todo: specify the possible exception @@ -35,10 +36,10 @@ def inner_f(queue, func, *args, **kwargs): # pragma: no cover queue.put(None) -def pl_multi_process(func): +def pl_multi_process(func: Callable) -> Callable: @functools.wraps(func) - def wrapper(*args, **kwargs): - queue = Queue() + def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]: + queue: Queue = Queue() proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) proc.start() proc.join(TPU_CHECK_TIMEOUT)