Commit c63202e

kaushikb11 authored and lexierule committed
Move parameter validation specific to TPU Training plugins (#7415)
* Move parameter validation specific to TPU Training plugins
* update docstring
1 parent 9c12662 commit c63202e

5 files changed: +22 -29 lines changed

pytorch_lightning/core/decorators.py

Lines changed: 4 additions & 5 deletions
@@ -71,12 +71,11 @@ def auto_transfer_args(self, *args, **kwargs):
 
 def parameter_validation(fn: Callable) -> Callable:
     """
-    Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method.
     Validates that the module parameter lengths match after moving to the device. It is useful
     when tying weights on TPU's.
 
     Args:
-        fn: ``.to`` method
+        fn: ``model_to_device`` method
 
     Note:
         TPU's require weights to be tied/shared after moving the module to the device.
@@ -90,10 +89,10 @@ def parameter_validation(fn: Callable) -> Callable:
 
     @wraps(fn)
     def inner_fn(self, *args, **kwargs):
-        pre_layer_count = len(list(self.parameters()))
+        pre_layer_count = len(list(self.model.parameters()))
         module = fn(self, *args, **kwargs)
-        self.on_post_move_to_device()
-        post_layer_count = len(list(self.parameters()))
+        self.model.on_post_move_to_device()
+        post_layer_count = len(list(self.model.parameters()))
 
         if not pre_layer_count == post_layer_count:
             rank_zero_warn(
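The hunks above change what the decorator expects of `self`: it now wraps a training type plugin's `model_to_device`, so the LightningModule is reached through `self.model` and its `on_post_move_to_device` hook is called after the move. The following is a minimal, self-contained sketch of that check, not the library's code: `ToySingleDevicePlugin` and `TiedModule` are illustrative stand-ins, and `warnings.warn` replaces the library's `rank_zero_warn` so the snippet runs without Lightning installed.

import warnings
from functools import wraps
from typing import Callable

import torch
from torch import nn


def parameter_validation(fn: Callable) -> Callable:
    # Sketch of the relocated decorator: `self` is a plugin, the module is `self.model`.
    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.model.parameters()))
        result = fn(self, *args, **kwargs)          # the plugin moves self.model to the device
        self.model.on_post_move_to_device()         # user hook re-ties shared weights
        post_layer_count = len(list(self.model.parameters()))
        if pre_layer_count != post_layer_count:
            # The library uses rank_zero_warn here; a plain warning keeps the sketch standalone.
            warnings.warn(
                f"The model layers do not match: {pre_layer_count} vs {post_layer_count} parameters."
            )
        return result

    return inner_fn


class TiedModule(nn.Module):
    """Toy module whose first and last layers share one weight matrix."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(8, 8, bias=False)
        self.layer_3 = nn.Linear(8, 8, bias=False)
        self.layer_3.weight = self.layer_1.weight

    def on_post_move_to_device(self):
        # Re-tie after the move; on an XLA device the move would otherwise untie the weights.
        self.layer_3.weight = self.layer_1.weight


class ToySingleDevicePlugin:
    """Hypothetical stand-in for a training type plugin that owns `self.model`."""

    def __init__(self, model: nn.Module, device: str = "cpu"):
        self.model = model
        self.root_device = torch.device(device)

    @parameter_validation
    def model_to_device(self) -> None:
        self.model.to(self.root_device)


plugin = ToySingleDevicePlugin(TiedModule())
plugin.model_to_device()  # counts match (one shared weight), so no warning is emitted

On a real XLA device the `.to()` call can re-create parameters and break the tie, which is why the hook re-ties them before the post-move count is taken.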

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 import torch
 
+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -43,6 +44,7 @@ def on_tpu(self) -> bool:
     def is_distributed(self) -> bool:
         return False
 
+    @parameter_validation
     def model_to_device(self) -> None:
         self.model.to(self.root_device)
 
pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
@@ -171,6 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.global_rank == 0:
             time.sleep(2)
 
+    @parameter_validation
     def model_to_device(self) -> None:
         self.device = xm.xla_device()
         self.model = self.wrapped_model.to(self.device)

pytorch_lightning/utilities/device_dtype_mixin.py

Lines changed: 0 additions & 6 deletions
@@ -17,8 +17,6 @@
 import torch
 from torch.nn import Module
 
-from pytorch_lightning.core.decorators import parameter_validation
-
 
 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ['device', 'dtype']
@@ -47,7 +45,6 @@ def device(self) -> Union[str, torch.device]:
 
         return device
 
-    @parameter_validation
     def to(self, *args, **kwargs) -> Module:
         """Moves and/or casts the parameters and buffers.
 
@@ -84,9 +81,6 @@ def to(self, *args, **kwargs) -> Module:
         ...     def __init__(self, weight: torch.Tensor):
         ...         super().__init__()
         ...         self.register_buffer('weight', weight)
-        ...
-        ...     def on_post_move_to_device(self):
-        ...         pass
         >>> _ = torch.manual_seed(0)
         >>> module = ExampleModule(torch.rand(3, 4))
         >>> module.weight #doctest: +ELLIPSIS

tests/accelerators/test_tpu_backend.py

Lines changed: 14 additions & 18 deletions
@@ -95,25 +95,21 @@ def test_weight_tying_warning(tmpdir, capsys=None):
     trainer.fit(model)
 
 
-# @RunIf(tpu=True)
-# @pl_multi_process_test
-# def test_if_weights_tied(tmpdir, capsys=None):
-#     """
-#     Test if weights are properly tied on `on_post_move_to_device`.
-#     Ensure no warning for parameter mismatch is thrown.
-#     """
-
-#     # TODO (kaushikb11): Add `parameter_validation` specific to TPU Accelerators
-#     class Model(WeightSharingModule):
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_if_weights_tied(tmpdir, capsys=None):
+    """
+    Test if weights are properly tied on `on_post_move_to_device`.
+    Ensure no warning for parameter mismatch is thrown.
+    """
 
-#         def on_post_move_to_device(self):
-#             self.layer_3.weight = self.layer_1.weight
+    class Model(WeightSharingModule):
 
-#     model = Model()
-#     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+        def on_post_move_to_device(self):
+            self.layer_3.weight = self.layer_1.weight
 
-#     with pytest.warns(UserWarning) as warnings:
-#         trainer.fit(model)
+    model = Model()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
 
-#     assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
-#     assert len(trainer.test(model)) == 1
+    with pytest.warns(UserWarning, match="The model layers do not match"):
+        trainer.fit(model)
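For reference, a user-side sketch of the scenario the re-enabled test exercises: a module that ties `layer_3.weight` to `layer_1.weight` and re-ties it in `on_post_move_to_device`, trained with `tpu_cores=1`. This assumes a TPU/XLA environment and the Lightning API of this commit; `WeightSharingModel` below is an illustrative module written from scratch, not the test's `WeightSharingModule` helper.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class WeightSharingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight  # tied at construction

    def forward(self, x):
        return self.layer_3(self.layer_2(self.layer_1(x)))

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.forward(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=32)

    def on_post_move_to_device(self):
        # Re-tie after the plugin's `model_to_device`; the relocated
        # `parameter_validation` check then confirms the parameter count is unchanged.
        self.layer_3.weight = self.layer_1.weight


model = WeightSharingModel()
trainer = pl.Trainer(max_epochs=1, tpu_cores=1)
trainer.fit(model)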
