23 | 23 | from pytorch_lightning.accelerators.tpu import TPUAccelerator |
24 | 24 | from pytorch_lightning.callbacks import Callback |
25 | 25 | from pytorch_lightning.plugins import TPUSpawnPlugin |
| 26 | +from pytorch_lightning.utilities import find_shared_parameters |
26 | 27 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
27 | 28 | from tests.helpers.boring_model import BoringModel |
28 | 29 | from tests.helpers.runif import RunIf |
@@ -80,37 +81,6 @@ def test_if_test_works_after_train(tmpdir): |
80 | 81 | assert len(trainer.test(model)) == 1 |
81 | 82 |
82 | 83 |
83 | | -@RunIf(tpu=True) |
84 | | -@pl_multi_process_test |
85 | | -def test_weight_tying_warning(tmpdir, capsys=None): |
86 | | - """Ensure a warning is thrown if model parameter lengths do not match post moving to device.""" |
87 | | - |
88 | | - model = WeightSharingModule() |
89 | | - trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) |
90 | | - |
91 | | - with pytest.warns(UserWarning, match=r"The model layers do not match after moving to the target device."): |
92 | | - trainer.fit(model) |
93 | | - |
94 | | - |
95 | | -@RunIf(tpu=True) |
96 | | -@pl_multi_process_test |
97 | | -def test_if_weights_tied(tmpdir, capsys=None): |
98 | | - """Test if weights are properly tied on `on_post_move_to_device`. |
99 | | -
100 | | - Ensure no warning for parameter mismatch is thrown. |
101 | | - """ |
102 | | - |
103 | | - class Model(WeightSharingModule): |
104 | | - def on_post_move_to_device(self): |
105 | | - self.layer_3.weight = self.layer_1.weight |
106 | | - |
107 | | - model = Model() |
108 | | - trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) |
109 | | - |
110 | | - with pytest.warns(UserWarning, match="The model layers do not match"): |
111 | | - trainer.fit(model) |
112 | | - |
113 | | - |
114 | 84 | @RunIf(tpu=True) |
115 | 85 | def test_accelerator_tpu(): |
116 | 86 |
@@ -257,3 +227,49 @@ def test_ddp_cpu_not_supported_on_tpus(): |
257 | 227 |
258 | 228 | with pytest.raises(MisconfigurationException, match="`accelerator='ddp_cpu'` is not supported on TPU machines"): |
259 | 229 | Trainer(accelerator="ddp_cpu") |
| 230 | + |
| 231 | + |
| 232 | +@RunIf(tpu=True) |
| 233 | +def test_auto_parameters_tying_tpus(tmpdir): |
| 234 | + |
| 235 | + model = WeightSharingModule() |
| 236 | + shared_params = find_shared_parameters(model) |
| 237 | + |
| 238 | + assert shared_params[0] == ["layer_1.weight", "layer_3.weight"] |
| 239 | + |
| 240 | + trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, tpu_cores=8, max_epochs=1) |
| 241 | + trainer.fit(model) |
| 242 | + |
| 243 | + assert torch.all(torch.eq(model.layer_1.weight, model.layer_3.weight)) |
| 244 | + |
| 245 | + |
| 246 | +@RunIf(tpu=True) |
| 247 | +def test_auto_parameters_tying_tpus_nested_module(tmpdir): |
| 248 | + class SubModule(nn.Module): |
| 249 | + def __init__(self, layer): |
| 250 | + super().__init__() |
| 251 | + self.layer = layer |
| 252 | + |
| 253 | + def forward(self, x): |
| 254 | + return self.layer(x) |
| 255 | + |
| 256 | + class NestedModule(BoringModel): |
| 257 | + def __init__(self): |
| 258 | + super().__init__() |
| 259 | + self.layer = nn.Linear(32, 10, bias=False) |
| 260 | + self.net_a = SubModule(self.layer) |
| 261 | + self.layer_2 = nn.Linear(10, 32, bias=False) |
| 262 | + self.net_b = SubModule(self.layer) |
| 263 | + |
| 264 | + def forward(self, x): |
| 265 | + x = self.net_a(x) |
| 266 | + x = self.layer_2(x) |
| 267 | + x = self.net_b(x) |
| 268 | + return x |
| 269 | + |
| 270 | + model = NestedModule() |
| 271 | + |
| 272 | + trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, tpu_cores=8, max_epochs=1) |
| 273 | + trainer.fit(model) |
| 274 | + |
| 275 | + assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight)) |
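For context, the new tests reference a `WeightSharingModule` helper and the `find_shared_parameters` utility that are defined outside this hunk. Below is a minimal sketch consistent with the assertions in `test_auto_parameters_tying_tpus`; the exact layer sizes are assumptions, since the helper's definition is not part of this diff.

```python
import torch.nn as nn

from pytorch_lightning.utilities import find_shared_parameters
from tests.helpers.boring_model import BoringModel


class WeightSharingModule(BoringModel):
    """Model whose first and third layers share a single weight tensor (sizes assumed)."""

    def __init__(self):
        super().__init__()
        # bias=False keeps the shared state down to a single parameter tensor
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # tie layer_3 to layer_1, mirroring the pair asserted in the test
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        return self.layer_3(x)


# find_shared_parameters groups parameter names that point at the same tensor object,
# which is what the test checks before and after fitting on TPU cores:
shared = find_shared_parameters(WeightSharingModule())
# expected: [["layer_1.weight", "layer_3.weight"]]
```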