Skip to content

Commit 27db899

Browse files
akihironittacarmocca
authored andcommitted
Fix materialize_module recursively setting its child module (#12870)
1 parent 8c70e25 commit 27db899

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616
- Fixed an issue to use wrapped `LightningModule` for evaluation during `trainer.fit` for `BaguaStrategy` ([#12983](https://github.com/PyTorchLightning/pytorch-lightning/pull/12983))
1717
- Fixed an issue wrt unnecessary usage of habana mixed precision package for fp32 types ([#13028](https://github.com/PyTorchLightning/pytorch-lightning/pull/13028))
1818
- Fixed the number of references of `LightningModule` so it can be deleted ([#12897](https://github.com/PyTorchLightning/pytorch-lightning/pull/12897))
19+
- Fixed `materialize_module` setting a module's child recursively ([#12870](https://github.com/PyTorchLightning/pytorch-lightning/pull/12870))
1920

2021

2122
## [1.6.3] - 2022-05-03

pytorch_lightning/utilities/meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
186186
if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
187187
materialize_module(child)
188188
else:
189-
setattr(child, name, materialize_fn())
189+
setattr(root_module, name, materialize_fn())
190190
return root_module
191191

192192

tests/utilities/test_meta.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import pytest
1415
from torch import nn
1516

1617
from pytorch_lightning.core.lightning import LightningModule
1718
from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
19+
from tests.helpers.boring_model import BoringModel
1820
from tests.helpers.runif import RunIf
1921

2022

@@ -24,7 +26,7 @@ def __init__(self, num_layers: int):
2426
self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)] + [nn.Dropout(), nn.LayerNorm(1)])
2527

2628

27-
class BoringModel(LightningModule):
29+
class SimpleBoringModel(LightningModule):
2830
def __init__(self, num_layers: int):
2931
super().__init__()
3032
self.save_hyperparameters()
@@ -48,7 +50,7 @@ def test_init_meta_context():
4850
assert not is_on_meta_device(mlp)
4951
assert not is_on_meta_device(nn.Module())
5052

51-
model = BoringModel(4)
53+
model = SimpleBoringModel(4)
5254
assert model.layer[0].weight.device.type == "meta"
5355
materialize_module(model)
5456
assert model.layer[0].weight.device.type == "cpu"
@@ -68,3 +70,15 @@ def test_init_meta_context():
6870

6971
m = nn.Linear(in_features=1, out_features=1)
7072
assert m.weight.device.type == "cpu"
73+
74+
75+
@RunIf(min_torch="1.10.0", standalone=True)
76+
def test_materialize_module_recursive_child():
77+
"""Test materialize_module doesn't set a child recursively to a model instantiated within init_meta_context."""
78+
with init_meta_context():
79+
model = BoringModel()
80+
81+
materialize_module(model)
82+
83+
with pytest.raises(AttributeError, match="'Linear' object has no attribute 'layer'"):
84+
model.layer.layer

0 commit comments

Comments
 (0)