Skip to content

Commit 9f19717

Browse files
committed
Add back GPUAccelerator and deprecate it
1 parent 7080ef7 commit 9f19717

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

src/pytorch_lightning/accelerators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401
1414
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
1515
from pytorch_lightning.accelerators.cuda import CUDAAccelerator # noqa: F401
16+
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401
1617
from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
1718
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
1819
from pytorch_lightning.accelerators.mps import MPSAccelerator # noqa: F401
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
15+
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
16+
17+
18+
class GPUAccelerator(CUDAAccelerator):
19+
"""Accelerator for NVIDIA GPU devices.
20+
21+
.. deprecated:: 1.9
22+
23+
Please use the ``CUDAAccelerator`` instead.
24+
"""
25+
26+
def __init__(self) -> None:
27+
rank_zero_deprecation(
28+
"The `GPUAccelerator` has been renamed to `CUDAAccelerator` and will be removed in v1.9."
29+
" Please use the `CUDAAccelerator` instead!"
30+
)
31+
super().__init__()

tests/tests_pytorch/deprecated_api/test_remove_1-9.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import pytorch_lightning.loggers.base as logger_base
2020
from pytorch_lightning import Trainer
21+
from pytorch_lightning.accelerators.gpu import GPUAccelerator
2122
from pytorch_lightning.core.module import LightningModule
2223
from pytorch_lightning.demos.boring_classes import BoringModel
2324
from pytorch_lightning.profiler.advanced import AdvancedProfiler
@@ -195,3 +196,13 @@ def test_pytorch_profiler_schedule_wrapper_deprecation_warning():
195196
def test_pytorch_profiler_register_record_function_deprecation_warning():
196197
with pytest.deprecated_call(match="RegisterRecordFunction` is deprecated in v1.7 and will be removed in in v1.9."):
197198
_ = RegisterRecordFunction(None)
199+
200+
201+
def test_gpu_accelerator_deprecation_warning():
202+
with pytest.deprecated_call(
203+
match=(
204+
"The `GPUAccelerator` has been renamed to `CUDAAccelerator` and will be removed in v1.9."
205+
+ " Please use the `CUDAAccelerator` instead!"
206+
)
207+
):
208+
GPUAccelerator()

0 commit comments

Comments
 (0)