
Commit f94faa9

Authored by kaushikb11, rohitgr7, awaelchli, and pre-commit-ci[bot]
Enable auto parameters tying for TPUs (#9525)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 86ad941 commit f94faa9

11 files changed (+284, -41 lines)

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -169,6 +169,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `torch.use_deterministic_algorithms` ([#9121](https://github.com/PyTorchLightning/pytorch-lightning/pull/9121))


+- Enabled automatic parameters tying for TPUs ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))
+
+
 ### Changed

 - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
@@ -299,6 +302,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))


+- Deprecated the `LightningModule.on_post_move_to_device` method ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))
+
+
+- Deprecated `pytorch_lightning.core.decorators.parameter_validation` in favor of `pytorch_lightning.utilities.parameter_tying.set_shared_parameters` ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))
+
+
 ### Removed

 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
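
For users, the practical effect of this change is that weight tying declared in a LightningModule's __init__ is re-established automatically after the model is moved to a TPU device, so overriding `on_post_move_to_device` is no longer necessary. A minimal before/after sketch, assuming a hypothetical module shaped like the test helper `WeightSharingModule` (layer sizes are illustrative, not taken from this diff):

import torch.nn as nn
import pytorch_lightning as pl


class WeightSharingModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # Tie layer_3 to layer_1 by assigning the same Parameter object.
        self.layer_3.weight = self.layer_1.weight

    # Previously required on TPUs to restore the tie after the device move;
    # deprecated by this commit and no longer needed:
    # def on_post_move_to_device(self):
    #     self.layer_3.weight = self.layer_1.weight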

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ module = [
     "pytorch_lightning.utilities.distributed",
     "pytorch_lightning.utilities.memory",
     "pytorch_lightning.utilities.model_summary",
+    "pytorch_lightning.utilities.parameter_tying",
     "pytorch_lightning.utilities.parsing",
     "pytorch_lightning.utilities.seed",
     "pytorch_lightning.utilities.xla_device",

pytorch_lightning/core/decorators.py

Lines changed: 10 additions & 4 deletions
@@ -11,12 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Decorator for LightningModule methods."""
+from pytorch_lightning.utilities import rank_zero_deprecation

-from functools import wraps
-from typing import Callable
+rank_zero_deprecation(
+    "Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, "
+    "and will be removed in v1.7. It has been replaced by automatic parameters tying with "
+    "`pytorch_lightning.utilities.params_tying.set_shared_parameters`"
+)

-from pytorch_lightning.utilities import rank_zero_warn
+from functools import wraps  # noqa: E402
+from typing import Callable  # noqa: E402
+
+from pytorch_lightning.utilities import rank_zero_warn  # noqa: E402


 def parameter_validation(fn: Callable) -> Callable:

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 15 additions & 3 deletions
@@ -14,12 +14,17 @@
 import os
 from typing import Any, Dict

-from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
-from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities import (
+    _OMEGACONF_AVAILABLE,
+    _TPU_AVAILABLE,
+    find_shared_parameters,
+    set_shared_parameters,
+)
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _PATH

 if _TPU_AVAILABLE:
@@ -49,7 +54,14 @@ def __init__(
     def is_distributed(self) -> bool:
         return False

-    @parameter_validation
+    def setup(self) -> None:
+        shared_params = find_shared_parameters(self.model)
+        self.model_to_device()
+        if is_overridden("on_post_move_to_device", self.lightning_module):
+            self.model.on_post_move_to_device()
+        else:
+            set_shared_parameters(self.model, shared_params)
+
     def model_to_device(self) -> None:
         self.model.to(self.root_device)

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 14 additions & 3 deletions
@@ -23,17 +23,23 @@
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
-from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import (
+    _OMEGACONF_AVAILABLE,
+    _TPU_AVAILABLE,
+    find_shared_parameters,
+    rank_zero_warn,
+    set_shared_parameters,
+)
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

@@ -156,7 +162,13 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
             trainer.progress_bar_callback.disable()

+        shared_params = find_shared_parameters(self.model)
         self.model_to_device()
+        if is_overridden("on_post_move_to_device", self.lightning_module):
+            self.model.module.on_post_move_to_device()
+        else:
+            set_shared_parameters(self.model.module, shared_params)
+
         trainer.accelerator.setup_optimizers(trainer)
         trainer.precision_plugin.connect(self._model, None, None)

@@ -176,7 +188,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         # ensure that spawned processes go through teardown before joining
         trainer._call_teardown_hook()

-    @parameter_validation
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
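
The ordering in both plugin changes above is the key point: shared parameters are recorded with `find_shared_parameters` before `model_to_device()`, because moving a module onto an XLA device can leave parameters that were tied by Python reference as separate copies, and the tie is then restored on the moved module with `set_shared_parameters` (unless the user still overrides `on_post_move_to_device`, which keeps working but is now deprecated). A simplified, device-agnostic sketch of that pattern, not the plugin code itself:

import torch
from torch import nn

from pytorch_lightning.utilities import find_shared_parameters, set_shared_parameters


def move_and_retie(module: nn.Module, device: torch.device) -> nn.Module:
    # Record which parameter names point at the same Parameter object ...
    shared_params = find_shared_parameters(module)
    # ... move the module (this step may break the ties on some backends) ...
    module = module.to(device)
    # ... and re-tie them on the moved module.
    set_shared_parameters(module, shared_params)
    return module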

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 15 additions & 0 deletions
@@ -46,6 +46,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
         self._check_add_get_queue(model)
         # TODO(@daniellepintz): Delete _check_progress_bar in v1.7
         self._check_progress_bar(model)
+        # TODO: Delete _check_on_post_move_to_device in v1.7
+        self._check_on_post_move_to_device(model)
         # TODO: Delete _check_on_keyboard_interrupt in v1.7
         self._check_on_keyboard_interrupt()

@@ -127,6 +129,19 @@ def _check_progress_bar(self, model: "pl.LightningModule") -> None:
             " Please use the `ProgressBarBase.get_metrics` instead."
         )

+    def _check_on_post_move_to_device(self, model: "pl.LightningModule") -> None:
+        r"""
+        Checks if `on_post_move_to_device` method is overriden and sends a deprecation warning.
+
+        Args:
+            model: The model to check the `on_post_move_to_device` method.
+        """
+        if is_overridden("on_post_move_to_device", model):
+            rank_zero_deprecation(
+                "Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7. "
+                "We perform automatic parameters tying without the need of implementing `on_post_move_to_device`."
+            )
+
     def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: str) -> None:
         loader_name = f"{stage}_dataloader"
         step_name = "validation_step" if stage == "val" else "test_step"

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@
     _TPU_AVAILABLE,
     _XLA_AVAILABLE,
 )
+from pytorch_lightning.utilities.parameter_tying import find_shared_parameters, set_shared_parameters  # noqa: F401
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
 from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn  # noqa: F401


pytorch_lightning/utilities/parameter_tying.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for automatic parameters tying.
+
+Reference:
+    https://github.com/pytorch/fairseq/blob/1f7ef9ed1e1061f8c7f88f8b94c7186834398690/fairseq/trainer.py#L110-L118
+"""
+from typing import Dict, List, Optional
+
+from torch import nn
+from torch.nn import Parameter
+
+
+def find_shared_parameters(module: nn.Module) -> List[str]:
+    """Returns a list of names of shared parameters set in the module."""
+    return _find_shared_parameters(module)
+
+
+def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[Dict] = None, prefix: str = "") -> List[str]:
+    if tied_parameters is None:
+        first_call = True
+        tied_parameters = {}
+    else:
+        first_call = False
+    for name, param in module._parameters.items():
+        param_prefix = prefix + ("." if prefix else "") + name
+        if param is None:
+            continue
+        if param not in tied_parameters:
+            tied_parameters[param] = []
+        tied_parameters[param].append(param_prefix)
+    for name, m in module._modules.items():
+        if m is None:
+            continue
+        submodule_prefix = prefix + ("." if prefix else "") + name
+        _find_shared_parameters(m, tied_parameters, submodule_prefix)
+    if first_call:
+        return [x for x in tied_parameters.values() if len(x) > 1]
+
+
+def set_shared_parameters(module: nn.Module, shared_params: list) -> nn.Module:
+    for shared_param in shared_params:
+        ref = _get_module_by_path(module, shared_param[0])
+        for path in shared_param[1:]:
+            _set_module_by_path(module, path, ref)
+    return module
+
+
+def _get_module_by_path(module: nn.Module, path: str) -> nn.Module:
+    path = path.split(".")
+    for name in path:
+        module = getattr(module, name)
+    return module
+
+
+def _set_module_by_path(module: nn.Module, path: str, value: Parameter) -> None:
+    path = path.split(".")
+    for name in path[:-1]:
+        module = getattr(module, name)
+    setattr(module, path[-1], value)
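
A quick CPU-only sketch of how the two public helpers behave; the tiny `TiedNet` module is hypothetical, but the returned name groups match the assertion in the new test below:

import torch
from torch import nn

from pytorch_lightning.utilities import find_shared_parameters, set_shared_parameters


class TiedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight  # tie by reference


net = TiedNet()
print(find_shared_parameters(net))  # [["layer_1.weight", "layer_3.weight"]]

# Break the tie, then restore it from the recorded parameter paths.
net.layer_3.weight = nn.Parameter(torch.zeros(10, 32))
set_shared_parameters(net, [["layer_1.weight", "layer_3.weight"]])
assert net.layer_3.weight is net.layer_1.weight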

tests/accelerators/test_tpu_backend.py

Lines changed: 47 additions & 31 deletions
@@ -23,6 +23,7 @@
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import TPUSpawnPlugin
+from pytorch_lightning.utilities import find_shared_parameters
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -80,37 +81,6 @@ def test_if_test_works_after_train(tmpdir):
     assert len(trainer.test(model)) == 1


-@RunIf(tpu=True)
-@pl_multi_process_test
-def test_weight_tying_warning(tmpdir, capsys=None):
-    """Ensure a warning is thrown if model parameter lengths do not match post moving to device."""
-
-    model = WeightSharingModule()
-    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
-
-    with pytest.warns(UserWarning, match=r"The model layers do not match after moving to the target device."):
-        trainer.fit(model)
-
-
-@RunIf(tpu=True)
-@pl_multi_process_test
-def test_if_weights_tied(tmpdir, capsys=None):
-    """Test if weights are properly tied on `on_post_move_to_device`.
-
-    Ensure no warning for parameter mismatch is thrown.
-    """
-
-    class Model(WeightSharingModule):
-        def on_post_move_to_device(self):
-            self.layer_3.weight = self.layer_1.weight
-
-    model = Model()
-    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
-
-    with pytest.warns(UserWarning, match="The model layers do not match"):
-        trainer.fit(model)
-
-
 @RunIf(tpu=True)
 def test_accelerator_tpu():

@@ -257,3 +227,49 @@ def test_ddp_cpu_not_supported_on_tpus():

     with pytest.raises(MisconfigurationException, match="`accelerator='ddp_cpu'` is not supported on TPU machines"):
         Trainer(accelerator="ddp_cpu")
+
+
+@RunIf(tpu=True)
+def test_auto_parameters_tying_tpus(tmpdir):
+
+    model = WeightSharingModule()
+    shared_params = find_shared_parameters(model)
+
+    assert shared_params[0] == ["layer_1.weight", "layer_3.weight"]
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, tpu_cores=8, max_epochs=1)
+    trainer.fit(model)
+
+    assert torch.all(torch.eq(model.layer_1.weight, model.layer_3.weight))
+
+
+@RunIf(tpu=True)
+def test_auto_parameters_tying_tpus_nested_module(tmpdir):
+    class SubModule(nn.Module):
+        def __init__(self, layer):
+            super().__init__()
+            self.layer = layer
+
+        def forward(self, x):
+            return self.layer(x)
+
+    class NestedModule(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(32, 10, bias=False)
+            self.net_a = SubModule(self.layer)
+            self.layer_2 = nn.Linear(10, 32, bias=False)
+            self.net_b = SubModule(self.layer)
+
+        def forward(self, x):
+            x = self.net_a(x)
+            x = self.layer_2(x)
+            x = self.net_b(x)
+            return x
+
+    model = NestedModule()
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, tpu_cores=8, max_epochs=1)
+    trainer.fit(model)
+
+    assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight))

tests/deprecated_api/test_remove_1-7.py

Lines changed: 24 additions & 0 deletions
@@ -255,3 +255,27 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
     from pytorch_lightning.distributed.dist import LightningDistributed

     _ = LightningDistributed()
+
+
+def test_v1_7_0_deprecate_on_post_move_to_device(tmpdir):
+    class TestModel(BoringModel):
+        def on_post_move_to_device(self):
+            print("on_post_move_to_device")
+
+    model = TestModel()
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, max_epochs=1)
+
+    with pytest.deprecated_call(
+        match=r"Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7"
+    ):
+        trainer.fit(model)
+
+
+def test_v1_7_0_deprecate_parameter_validation():
+
+    _soft_unimport_module("pytorch_lightning.core.decorators")
+    with pytest.deprecated_call(
+        match="Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5"
+    ):
+        from pytorch_lightning.core.decorators import parameter_validation  # noqa: F401
