Skip to content

Commit 04f4448

Browse files
authored
Remove the deprecated GPUAccelerator (#16050)
1 parent 8d3339a commit 04f4448

File tree

7 files changed

+46
-43
lines changed

7 files changed

+46
-43
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9292
- Removed the deprecated `LightningDeepSpeedModule` ([#16041](https://github.com/Lightning-AI/lightning/pull/16041))
9393

9494

95+
- Removed the deprecated `pytorch_lightning.accelerators.GPUAccelerator` in favor of `pytorch_lightning.accelerators.CUDAAccelerator` ([#16050](https://github.com/Lightning-AI/lightning/pull/16050))
96+
97+
98+
9599
### Fixed
96100

97101
- Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))

src/pytorch_lightning/_graveyard/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pytorch_lightning._graveyard.accelerator
1516
import pytorch_lightning._graveyard.callbacks
1617
import pytorch_lightning._graveyard.core
1718
import pytorch_lightning._graveyard.legacy_import_unpickler
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import sys
2+
from typing import Any
3+
4+
import pytorch_lightning as pl
5+
6+
7+
def _patch_sys_modules() -> None:
8+
# TODO: Remove in v2.0.0
9+
self = sys.modules[__name__]
10+
sys.modules["pytorch_lightning.accelerators.gpu"] = self
11+
12+
13+
class GPUAccelerator:
14+
# TODO: Remove in v2.0.0
15+
def __init__(self, *_: Any, **__: Any) -> None:
16+
raise NotImplementedError(
17+
"`pytorch_lightning.accelerators.gpu.GPUAccelerator` was deprecated in v1.7.0 and is no"
18+
" longer supported as of v1.9.0. Please use `pytorch_lightning.accelerators.CUDAAccelerator` instead"
19+
)
20+
21+
22+
def _patch_classes() -> None:
23+
# TODO: Remove in v2.0.0
24+
setattr(pl.accelerators, "GPUAccelerator", GPUAccelerator)
25+
26+
27+
_patch_sys_modules()
28+
_patch_classes()

src/pytorch_lightning/accelerators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401
1515
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
1616
from pytorch_lightning.accelerators.cuda import CUDAAccelerator # noqa: F401
17-
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401
1817
from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
1918
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
2019
from pytorch_lightning.accelerators.mps import MPSAccelerator # noqa: F401

src/pytorch_lightning/accelerators/gpu.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

tests/tests_pytorch/deprecated_api/test_remove_1-9.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import pytorch_lightning.loggers.base as logger_base
2121
import pytorch_lightning.utilities.cli as old_cli
2222
from pytorch_lightning import Trainer
23-
from pytorch_lightning.accelerators.gpu import GPUAccelerator
2423
from pytorch_lightning.cli import LightningCLI, SaveConfigCallback
2524
from pytorch_lightning.core.module import LightningModule
2625
from pytorch_lightning.demos.boring_classes import BoringModel
@@ -207,13 +206,3 @@ def test_pytorch_profiler_schedule_wrapper_deprecation_warning():
207206
def test_pytorch_profiler_register_record_function_deprecation_warning():
208207
with pytest.deprecated_call(match="RegisterRecordFunction` is deprecated in v1.7 and will be removed in in v1.9."):
209208
_ = RegisterRecordFunction(None)
210-
211-
212-
def test_gpu_accelerator_deprecation_warning():
213-
with pytest.deprecated_call(
214-
match=(
215-
"The `GPUAccelerator` has been renamed to `CUDAAccelerator` and will be removed in v1.9."
216-
+ " Please use the `CUDAAccelerator` instead!"
217-
)
218-
):
219-
GPUAccelerator()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
3+
4+
def test_removed_gpuaccelerator():
5+
from pytorch_lightning.accelerators.gpu import GPUAccelerator
6+
7+
with pytest.raises(NotImplementedError, match="GPUAccelerator`.*no longer supported as of v1.9"):
8+
GPUAccelerator()
9+
10+
from pytorch_lightning.accelerators import GPUAccelerator
11+
12+
with pytest.raises(NotImplementedError, match="GPUAccelerator`.*no longer supported as of v1.9"):
13+
GPUAccelerator()

0 commit comments

Comments
 (0)