 
 import pytest
 import torch
+from torch import nn
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.utils import pl_multi_process_test
 
 
-@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
+class WeightSharingModule(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.layer_1 = nn.Linear(32, 10, bias=False)
+        self.layer_2 = nn.Linear(10, 32, bias=False)
+        self.layer_3 = nn.Linear(32, 10, bias=False)
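+        # layer_3 is tied to layer_1: both layers share the same weight tensor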
+        self.layer_3.weight = self.layer_1.weight
+
+    def forward(self, x):
+        x = self.layer_1(x)
+        x = self.layer_2(x)
+        x = self.layer_3(x)
+        return x
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
 def test_resume_training_on_cpu(tmpdir):
     """ Checks if training can be resumed from a saved checkpoint on CPU"""
@@ -53,7 +70,7 @@ def test_resume_training_on_cpu(tmpdir):
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
-@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
 def test_if_test_works_after_train(tmpdir):
     """ Ensure that .test() works after .fit() """
@@ -63,3 +80,43 @@ def test_if_test_works_after_train(tmpdir):
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model)
     assert trainer.test(model) == 1
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_weight_tying_warning(tmpdir, capsys=None):
+    """
+    Ensure a warning is raised if the model's parameter counts do not match
+    after moving to the device.
+    """
+
+    model = WeightSharingModule()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+
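+    # moving the model to the XLA device recreates its parameters, which breaks
+    # the weight tie set up in __init__ and should trigger the mismatch warning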
+    with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'):
+        result = trainer.fit(model)
+        assert result
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_if_weights_tied(tmpdir, capsys=None):
+    """
+    Test that weights are re-tied in `on_post_move_to_device` and that no
+    parameter-mismatch warning is raised.
+    """
+
+    class Model(WeightSharingModule):
+
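+        # re-tie the shared weight once the module has been moved to the device,
+        # restoring the link that the device transfer may have broken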
+        def on_post_move_to_device(self):
+            self.layer_3.weight = self.layer_1.weight
+
+    model = Model()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+
+    with pytest.warns(UserWarning) as warnings:
+        result = trainer.fit(model)
+        assert result
+
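+    # no parameter-mismatch warning should have been recorded during fit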
+    assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
+    assert trainer.test(model) == 1