100 changes: 100 additions & 0 deletions src/lightning_lite/utilities/xla_device.py
@@ -0,0 +1,100 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import queue as q
import traceback
from multiprocessing import Process, Queue
from typing import Any, Callable, Union

from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

if _XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

# define TPU availability timeout in seconds
TPU_CHECK_TIMEOUT = 60


def inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
    try:
        queue.put(func(*args, **kwargs))
    # todo: specify the possible exception
    except Exception:
        traceback.print_exc()
        queue.put(None)


def pl_multi_process(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
        queue: Queue = Queue()
        proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
        proc.start()
        proc.join(TPU_CHECK_TIMEOUT)
        try:
            return queue.get_nowait()
        except q.Empty:
            traceback.print_exc()
            return False

    return wrapper


class XLADeviceUtils:
    """Used to detect the type of XLA device."""

    _TPU_AVAILABLE = False

    @staticmethod
    @pl_multi_process
    def _is_device_tpu() -> bool:
        """Check if TPU devices are available.

        Return:
            A boolean value indicating if TPU devices are available
        """
        # For the TPU Pod training process, for example, if we have
        # TPU v3-32 with 4 VMs, the world size would be 4 and as
        # we would have to use `torch_xla.distributed.xla_dist` for
        # multiple VMs and TPU_CONFIG won't be available, running
        # `xm.get_xla_supported_devices("TPU")` won't be possible.
        return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))

    @staticmethod
    def xla_available() -> bool:
        """Check if the XLA library is installed.

        Return:
            A boolean value indicating if XLA is installed
        """
        return _XLA_AVAILABLE

    @staticmethod
    def tpu_device_exists() -> bool:
        """Runs the XLA device check within a separate process.

        Return:
            A boolean value indicating if a TPU device exists on the system
        """
        if os.getenv("PL_TPU_AVAILABLE", "0") == "1":
            XLADeviceUtils._TPU_AVAILABLE = True

        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()

        if XLADeviceUtils._TPU_AVAILABLE:
            os.environ["PL_TPU_AVAILABLE"] = "1"
        return XLADeviceUtils._TPU_AVAILABLE
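
A minimal usage sketch of the entry point added above (not part of this diff; it assumes the package is importable as `lightning_lite`):

```python
# Sketch: querying TPU availability via the moved utility.
# `tpu_device_exists` runs the actual probe in a subprocess (see
# `pl_multi_process` above) and caches a positive result in the
# PL_TPU_AVAILABLE environment variable for later callers.
from lightning_lite.utilities.xla_device import XLADeviceUtils

if XLADeviceUtils.tpu_device_exists():
    print("TPU available")
else:
    print("No TPU found (or the check timed out after TPU_CHECK_TIMEOUT seconds)")
```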
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -75,6 +75,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the `pl.core.mixins.DeviceDtypeModuleMixin` in favor of `lightning_lite.utilities.DeviceDtypeModuleMixin` ([#14511](https://github.com/Lightning-AI/lightning/pull/14511))
 
 
+- Deprecated all functions in `pytorch_lightning.utilities.xla_device` in favor of `lightning_lite.utilities.xla_device` ([#14514](https://github.com/Lightning-AI/lightning/pull/14514))
+
+
 
 ### Removed
 
 - Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011))
2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/imports.py
@@ -147,7 +147,7 @@ def __repr__(self) -> str:
 _XLA_AVAILABLE: bool = _package_available("torch_xla")
 
 
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: E402
+from lightning_lite.utilities.xla_device import XLADeviceUtils  # noqa: E402
 
 _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
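
The `_XLA_AVAILABLE` flag in the hunk above is what gates the optional `torch_xla` import in the moved module. A sketch of that guard pattern (not part of this diff; `package_available` here is a stand-in for Lightning's `_package_available` helper):

```python
import importlib.util

# Stand-in for Lightning's `_package_available`: True when the package
# can be located without actually importing it.
def package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None

_XLA_AVAILABLE: bool = package_available("torch_xla")

if _XLA_AVAILABLE:
    # Imported only when torch_xla is installed, so the module stays
    # importable on machines without XLA support.
    import torch_xla.core.xla_model as xm
```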
104 changes: 33 additions & 71 deletions src/pytorch_lightning/utilities/xla_device.py
@@ -11,90 +11,52 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import functools
-import os
-import queue as q
-import traceback
-from multiprocessing import Process, Queue
-from typing import Any, Callable, Union
-
-from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
+from multiprocessing import Queue
+from typing import Any, Callable
 
-if _XLA_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-
-# define TPU availability timeout in seconds
-TPU_CHECK_TIMEOUT = 60
+from lightning_lite.utilities.xla_device import inner_f as new_inner_f
+from lightning_lite.utilities.xla_device import pl_multi_process as new_pl_multi_process
+from lightning_lite.utilities.xla_device import XLADeviceUtils as NewXLADeviceUtils
+from pytorch_lightning.utilities import rank_zero_deprecation  # TODO(lite): update to lightning_lite.utilities
 
 
 def inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
-    try:
-        queue.put(func(*args, **kwargs))
-    # todo: specify the possible exception
-    except Exception:
-        traceback.print_exc()
-        queue.put(None)
+    rank_zero_deprecation(
+        "`pytorch_lightning.utilities.xla_device.inner_f` has been deprecated in v1.8.0 and will be"
+        " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.inner_f` instead."
+    )
+    return new_inner_f(queue, func, *args, **kwargs)
 
 
 def pl_multi_process(func: Callable) -> Callable:
-    @functools.wraps(func)
-    def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
-        queue: Queue = Queue()
-        proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
-        proc.start()
-        proc.join(TPU_CHECK_TIMEOUT)
-        try:
-            return queue.get_nowait()
-        except q.Empty:
-            traceback.print_exc()
-            return False
-
-    return wrapper
+    rank_zero_deprecation(
+        "`pytorch_lightning.utilities.xla_device.pl_multi_process` has been deprecated in v1.8.0 and will be"
+        " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.pl_multi_process` instead."
+    )
+    return new_pl_multi_process(func)
 
 
-class XLADeviceUtils:
-    """Used to detect the type of XLA device."""
-
-    _TPU_AVAILABLE = False
-
-    @staticmethod
-    @pl_multi_process
-    def _is_device_tpu() -> bool:
-        """Check if TPU devices are available.
-
-        Return:
-            A boolean value indicating if TPU devices are available
-        """
-        # For the TPU Pod training process, for example, if we have
-        # TPU v3-32 with 4 VMs, the world size would be 4 and as
-        # we would have to use `torch_xla.distributed.xla_dist` for
-        # multiple VMs and TPU_CONFIG won't be available, running
-        # `xm.get_xla_supported_devices("TPU")` won't be possible.
-        return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))
+class XLADeviceUtils(NewXLADeviceUtils):
+    def __init__(self) -> None:
+        rank_zero_deprecation(
+            "`pytorch_lightning.utilities.xla_device.XLADeviceUtils` has been deprecated in v1.8.0 and will be"
+            " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.XLADeviceUtils` instead."
+        )
+        super().__init__()
 
     @staticmethod
     def xla_available() -> bool:
-        """Check if XLA library is installed.
-
-        Return:
-            A boolean value indicating if a XLA is installed
-        """
-        return _XLA_AVAILABLE
+        rank_zero_deprecation(
+            "`pytorch_lightning.utilities.xla_device.XLADeviceUtils` has been deprecated in v1.8.0 and will be"
+            " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.XLADeviceUtils` instead."
+        )
+        return NewXLADeviceUtils.xla_available()
 
     @staticmethod
     def tpu_device_exists() -> bool:
-        """Runs XLA device check within a separate process.
-
-        Return:
-            A boolean value indicating if a TPU device exists on the system
-        """
-        if os.getenv("PL_TPU_AVAILABLE", "0") == "1":
-            XLADeviceUtils._TPU_AVAILABLE = True
-
-        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
-            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
-
-        if XLADeviceUtils._TPU_AVAILABLE:
-            os.environ["PL_TPU_AVAILABLE"] = "1"
-        return XLADeviceUtils._TPU_AVAILABLE
+        rank_zero_deprecation(
+            "`pytorch_lightning.utilities.xla_device.XLADeviceUtils` has been deprecated in v1.8.0 and will be"
+            " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.XLADeviceUtils` instead."
+        )
+        return NewXLADeviceUtils.tpu_device_exists()
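
For downstream code, the migration implied by these shims is a one-line import change; a sketch (not part of this diff):

```python
# Deprecated as of v1.8.0: emits a rank-zero DeprecationWarning and is
# scheduled for removal in v1.10.0.
from pytorch_lightning.utilities.xla_device import XLADeviceUtils

# Replacement: same class, same API, no warning.
from lightning_lite.utilities.xla_device import XLADeviceUtils
```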
@@ -16,7 +16,7 @@
 
 import pytest
 
-import pytorch_lightning.utilities.xla_device as xla_utils
+import lightning_lite.utilities.xla_device as xla_utils
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 from tests_pytorch.helpers.runif import RunIf
 
20 changes: 20 additions & 0 deletions tests/tests_pytorch/deprecated_api/test_remove_1-10.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Test deprecated functionality which will be removed in v1.10.0."""
+from unittest import mock
+
 import pytest
 
 from pytorch_lightning import Trainer
@@ -24,6 +26,7 @@
 from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.strategies.ipu import LightningIPUModule
 from pytorch_lightning.strategies.utils import on_colab_kaggle
+from pytorch_lightning.utilities.xla_device import inner_f, pl_multi_process, XLADeviceUtils
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.utils import no_warning_call
@@ -78,3 +81,20 @@ class MyModule(DeviceDtypeModuleMixin):
 
     with pytest.deprecated_call(match="mixins.DeviceDtypeModuleMixin` has been deprecated in v1.8.0"):
         MyModule()
+
+
+def test_v1_10_deprecated_xla_device_utilities():
+    with pytest.deprecated_call(match="xla_device.inner_f` has been deprecated in v1.8.0"):
+        inner_f(mock.Mock(), mock.Mock())
+
+    with pytest.deprecated_call(match="xla_device.pl_multi_process` has been deprecated in v1.8.0"):
+        pl_multi_process(mock.Mock)
+
+    with pytest.deprecated_call(match="xla_device.XLADeviceUtils` has been deprecated in v1.8.0"):
+        XLADeviceUtils()
+
+    with pytest.deprecated_call(match="xla_device.XLADeviceUtils` has been deprecated in v1.8.0"):
+        XLADeviceUtils.xla_available()
+
+    with pytest.deprecated_call(match="xla_device.XLADeviceUtils` has been deprecated in v1.8.0"):
+        XLADeviceUtils.tpu_device_exists()
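
The tests above rely on `pytest.deprecated_call`, which fails unless the wrapped call emits a matching `DeprecationWarning`. A self-contained sketch of the same assertion pattern (not part of this diff):

```python
import warnings

import pytest


def old_api() -> None:
    # Stand-in for a deprecated function like the shims in this PR.
    warnings.warn("old_api is deprecated", DeprecationWarning)


def test_old_api_warns() -> None:
    # Passes only if the call emits a DeprecationWarning whose message
    # matches the given pattern.
    with pytest.deprecated_call(match="old_api is deprecated"):
        old_api()
```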