
Commit d243617

Provide access to unwrapped model in Lite (#12597)
Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 4011f37 commit d243617
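
For readers skimming the diff: this change lets user code reach attributes and methods of the original model through the wrapper that `LightningLite.setup()` returns. Below is a minimal usage sketch, assuming the public `LightningLite` API of this release; the `MyModel` class and its `num_features` helper are hypothetical and exist only to illustrate the pass-through.

```python
import torch.nn as nn

from pytorch_lightning.lite import LightningLite


class MyModel(nn.Module):  # hypothetical user model, not part of this commit
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        return self.layer(x)

    def num_features(self) -> int:  # hypothetical helper method
        return self.layer.in_features


class Lite(LightningLite):
    def run(self):
        model = MyModel()
        lite_model = self.setup(model)
        # New behavior: `.module` returns the unwrapped model, and attribute
        # lookup on the wrapper falls through to the original module.
        assert lite_model.module is model
        assert lite_model.num_features() == 2


Lite(accelerator="cpu").run()
```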

5 files changed (+80, -16 lines)

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
@@ -197,10 +197,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))


-- Avoid redundant callback restore warning while tuning ([#13026](https://github.com/PyTorchLightning/pytorch-lightning/pull/13026))
+- The model wrapper returned by `LightningLite.setup()` now properly supports pass-through when looking up attributes ([#12597](https://github.com/PyTorchLightning/pytorch-lightning/pull/12597))


--
+- Avoid redundant callback restore warning while tuning ([#13026](https://github.com/PyTorchLightning/pytorch-lightning/pull/13026))


 -

pytorch_lightning/lite/lite.py

Lines changed: 9 additions & 8 deletions
@@ -152,25 +152,26 @@ def setup(
         *optimizers: Optimizer,
         move_to_device: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
-        """Setup a model and its optimizers for accelerated training.
+        """Set up a model and its optimizers for accelerated training.

         Args:
-            model: A model to setup
-            *optimizers: The optimizer(s) to setup (no optimizers is also possible)
+            model: A model to set up
+            *optimizers: The optimizer(s) to set up (no optimizers is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.

         Returns:
             The tuple of the wrapped model and list of optimizers, in the same order they were passed in.
         """
         self._validate_setup(model, optimizers)
+        original_model = model

         if move_to_device:
             model = self._move_model_to_device(model=model, optimizers=list(optimizers))

         # Let accelerator/plugin wrap and connect the models and optimizers
         model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
-        model = _LiteModule(model, self._precision_plugin)
+        model = _LiteModule(model, self._precision_plugin, original_module=original_model)
         optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
         self._models_setup += 1
         if optimizers:
@@ -181,7 +182,7 @@ def setup(
     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
     ) -> Union[DataLoader, List[DataLoader]]:
-        """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
+        """Set up one or multiple dataloaders for accelerated training. If you need different settings for each
         dataloader, call this method individually for each one.

         Args:
@@ -206,7 +207,7 @@ def setup_dataloaders(
     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
     ) -> DataLoader:
-        """Setup a single dataloader for accelerated training.
+        """Set up a single dataloader for accelerated training.

         Args:
             dataloader: The dataloader to accelerate.
@@ -252,10 +253,10 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No
             **kwargs: Optional named keyword arguments passed to the underlying backward function.

         Note:
-            When using ``strategy="deepspeed"`` and multiple models were setup, it is required to pass in the
+            When using ``strategy="deepspeed"`` and multiple models were set up, it is required to pass in the
             model as argument here.
         """
-        module = model.module if model is not None else model
+        module = model._forward_module if model is not None else model
         if isinstance(self._strategy, DeepSpeedStrategy):
             if model is None:
                 if self._models_setup == 0:

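The `setup()` change above keeps a reference to the model before the strategy wraps it, and `backward()` now resolves the strategy-wrapped module via `_forward_module`. The following is a hedged single-device sketch of the call sites this touches; the toy training step is illustrative, and only `setup()` and `backward(model=...)` come from this diff.

```python
import torch
import torch.nn as nn

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        model = nn.Linear(2, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        # `setup` lets the strategy/precision wrap the model but keeps a
        # reference to the original module (the `original_model` variable
        # introduced in this diff).
        model, optimizer = self.setup(model, optimizer)

        loss = model(torch.randn(4, 2)).sum()
        # Passing the wrapper is only mandatory for strategy="deepspeed" with
        # several models, but it is always allowed; internally Lite now reads
        # `model._forward_module` rather than `model.module`.
        self.backward(loss, model=model)
        optimizer.step()


Lite(accelerator="cpu").run()
```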
pytorch_lightning/lite/wrappers.py

Lines changed: 21 additions & 5 deletions
@@ -65,23 +65,29 @@ def step(self, closure: Optional[Callable] = None) -> Any:


 class _LiteModule(DeviceDtypeModuleMixin):
-    def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
+    def __init__(
+        self, forward_module: nn.Module, precision_plugin: PrecisionPlugin, original_module: Optional[nn.Module] = None
+    ) -> None:
         """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
         automatically for the forward pass.

         The underlying wrapped module can be accessed via the property :attr:`module`.

         Args:
-            module: The module to wrap
+            forward_module: The module to wrap the ``forward`` method on.
             precision_plugin: Reference to the precision plugin for handling precision context
+            original_module: The original, unmodified module as passed into the
+                :meth:`pytorch_lightning.lite.lite.LightningLite.setup` method. This is needed when attribute lookup
+                on this wrapper should pass through to the original module.
         """
         super().__init__()
-        self._module = module
+        self._forward_module = forward_module
+        self._original_module = original_module or forward_module
         self._precision_plugin = precision_plugin

     @property
     def module(self) -> nn.Module:
-        return self._module
+        return self._original_module or self._forward_module

     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Casts all inputs to the right precision and handles autocast for operations in the module forward
@@ -102,12 +108,22 @@ def _convert_float_tensor(t: Tensor) -> Tensor:
         args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor)

         with self._precision_plugin.forward_context():
-            output = self.module(*args, **kwargs)
+            output = self._forward_module(*args, **kwargs)

         to_type = torch.get_default_dtype()
         output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor)
         return output

+    def __getattr__(self, item: Any) -> Any:
+        try:
+            # __getattr__ gets called as a last resort if the attribute does not exist
+            # call nn.Module's implementation first
+            return super().__getattr__(item)
+        except AttributeError:
+            # If the attribute is not available on the _LiteModule wrapper, redirect to the wrapped nn.Module
+            original_module = super().__getattr__("_original_module")
+            return getattr(original_module, item)
+

 class _LiteDataLoader:
     def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:

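To make the two references a `_LiteModule` now holds concrete, here is a self-contained sketch mirroring the tests further down in this commit. `nn.Sequential(original)` stands in for whatever wrapper a strategy such as DDP would apply, and `Mock()` stands in for a real `PrecisionPlugin`; neither is how Lite wires this up internally.

```python
import torch.nn as nn
from unittest.mock import Mock

from pytorch_lightning.lite.wrappers import _LiteModule

original = nn.Linear(2, 3)
# Stand-in for the wrapper a strategy (e.g. DDP) would put around the model.
strategy_wrapped = nn.Sequential(original)

lite_module = _LiteModule(strategy_wrapped, Mock(), original_module=original)

assert lite_module.module is original                    # unwrapped, original model
assert lite_module._forward_module is strategy_wrapped   # what forward()/backward() use
assert lite_module.weight is original.weight             # attribute lookup passes through
```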
tests/lite/test_lite.py

Lines changed: 13 additions & 1 deletion
@@ -14,7 +14,7 @@
 import os
 from copy import deepcopy
 from unittest import mock
-from unittest.mock import MagicMock, Mock, PropertyMock
+from unittest.mock import ANY, MagicMock, Mock, PropertyMock

 import pytest
 import torch
@@ -80,6 +80,18 @@ def run(self, *args, **kwargs):
     assert lite.run_kwargs == {"three": 3}


+@mock.patch("pytorch_lightning.strategies.ddp.DistributedDataParallel")
+def test_setup_model(ddp_mock):
+    """Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model."""
+    lite = EmptyLite(accelerator="cpu", strategy="ddp", devices=2)
+    model = nn.Linear(1, 2)
+    lite_model = lite.setup(model)
+    ddp_mock.assert_called_with(module=model, device_ids=ANY)
+    assert lite_model.module == model
+    assert lite_model.weight is model.weight
+    assert lite_model.forward != model.forward
+
+
 def test_setup_optimizers():
     """Test that setup_optimizers can handle no optimizers, one optimizer, or multiple optimizers."""
     lite = EmptyLite()

tests/lite/test_wrappers.py

Lines changed: 35 additions & 0 deletions
@@ -33,6 +33,41 @@ def test_lite_module_wraps():
     module = Mock()
     assert _LiteModule(module, Mock()).module is module

+    wrapped_module = Mock()
+    original_module = Mock()
+    assert _LiteModule(wrapped_module, Mock(), original_module=original_module).module is original_module
+
+
+def test_lite_module_attribute_lookup():
+    """Test that attribute lookup passes through to the original model when possible."""
+
+    class OriginalModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer = torch.nn.Linear(2, 3)
+            self.attribute = 1
+
+        def method(self):
+            return 2
+
+    original_module = OriginalModule()
+
+    class ModuleWrapper(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.wrapped = original_module
+
+    wrapped_module = ModuleWrapper()
+
+    lite_module = _LiteModule(wrapped_module, Mock(), original_module=original_module)
+    assert lite_module.attribute == 1
+    assert lite_module.layer is original_module.layer
+    assert lite_module.method() == 2
+    assert lite_module.forward.__self__.__class__ == _LiteModule
+
+    with pytest.raises(AttributeError):
+        _ = lite_module.not_exists
+

 @RunIf(min_gpus=1)
 @pytest.mark.parametrize(
