Skip to content

Commit 7deac02

Browse files
awaelchlipre-commit-ci[bot]
authored andcommitted
Make LightningModule torch.jit.script-able again (#15947)
* Make LightningModule torch.jit.script-able again * remove skip Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit b5fa896)
1 parent 8542b26 commit 7deac02

File tree

5 files changed

+13
-39
lines changed

5 files changed

+13
-39
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1515
- Fixed issue with unsupported torch.inference_mode() on hpu backends ([#15918](https://github.com/Lightning-AI/lightning/pull/15918))
1616
- Fixed LRScheduler import for PyTorch 2.0 ([#15940](https://github.com/Lightning-AI/lightning/pull/15940))
1717
- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620))
18+
- Fixed `torch.jit.script`-ing a LightningModule causing an unintended error message about deprecated `use_amp` property ([#15947](https://github.com/Lightning-AI/lightning/pull/15947))
1819

1920

2021
## [1.8.3] - 2022-11-22

src/pytorch_lightning/_graveyard/core.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from typing import Any
1515

16-
from pytorch_lightning import LightningDataModule, LightningModule
16+
from pytorch_lightning import LightningDataModule
1717

1818

1919
def _on_save_checkpoint(_: LightningDataModule, __: Any) -> None:
@@ -32,28 +32,6 @@ def _on_load_checkpoint(_: LightningDataModule, __: Any) -> None:
3232
)
3333

3434

35-
def _use_amp(_: LightningModule) -> None:
36-
# Remove in v2.0.0 and the skip in `__jit_unused_properties__`
37-
if not LightningModule._jit_is_scripting:
38-
# cannot use `AttributeError` as it messes up with `nn.Module.__getattr__`
39-
raise RuntimeError(
40-
"`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
41-
" Please use `Trainer.amp_backend`.",
42-
)
43-
44-
45-
def _use_amp_setter(_: LightningModule, __: bool) -> None:
46-
# Remove in v2.0.0
47-
# cannot use `AttributeError` as it messes up with `nn.Module.__getattr__`
48-
raise RuntimeError(
49-
"`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
50-
" Please use `Trainer.amp_backend`.",
51-
)
52-
53-
54-
# Properties
55-
LightningModule.use_amp = property(fget=_use_amp, fset=_use_amp_setter)
56-
5735
# Methods
5836
LightningDataModule.on_save_checkpoint = _on_save_checkpoint
5937
LightningDataModule.on_load_checkpoint = _on_load_checkpoint

src/pytorch_lightning/core/module.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ class LightningModule(
8888
"automatic_optimization",
8989
"truncated_bptt_steps",
9090
"trainer",
91-
"use_amp", # from graveyard
9291
]
9392
+ _DeviceDtypeModuleMixin.__jit_unused_properties__
9493
+ HyperparametersMixin.__jit_unused_properties__

tests/tests_pytorch/core/test_lightning_module.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,17 @@ def test_proper_refcount():
425425
assert sys.getrefcount(torch_module) == sys.getrefcount(lightning_module)
426426

427427

428+
def test_lightning_module_scriptable():
429+
"""Test that the LightningModule is `torch.jit.script`-able.
430+
431+
Regression test for #15917.
432+
"""
433+
model = BoringModel()
434+
trainer = Trainer()
435+
model.trainer = trainer
436+
torch.jit.script(model)
437+
438+
428439
def test_trainer_reference_recursively():
429440
ensemble = LightningModule()
430441
inner = LightningModule()

tests/tests_pytorch/graveyard/test_core.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,3 @@ def on_load_checkpoint(self, checkpoint):
5353
match="`LightningDataModule.on_load_checkpoint`.*no longer supported as of v1.8.",
5454
):
5555
trainer.fit(model, OnLoadDataModule())
56-
57-
58-
def test_v2_0_0_lightning_module_unsupported_use_amp():
59-
model = BoringModel()
60-
with pytest.raises(
61-
RuntimeError,
62-
match="`LightningModule.use_amp`.*no longer accessible as of v1.8.",
63-
):
64-
model.use_amp
65-
66-
with pytest.raises(
67-
RuntimeError,
68-
match="`LightningModule.use_amp`.*no longer accessible as of v1.8.",
69-
):
70-
model.use_amp = False

0 commit comments

Comments
 (0)