@@ -12,11 +12,10 @@ Weight Tying/Sharing is a technique wherein the module weights are shared among
 This is a common method to reduce memory consumption and is utilized in many State of the Art
 architectures today.
 
-PyTorch XLA requires these weights to be tied/shared after moving the model
-to the TPU device. To support this requirement Lightning provides a model hook which is
-called after the model is moved to the device. Any weights that require to be tied should
-be done in the `on_post_move_to_device` model hook. This will ensure that the weights
-among the modules are shared and not copied.
+PyTorch XLA requires these weights to be tied/shared after moving the model to the XLA device.
+To support this requirement, Lightning finds these weights and ties them automatically after
+the modules are moved to the XLA device. This ensures that the weights among the modules are
+shared and not copied independently.
 
 PyTorch Lightning has an inbuilt check which verifies that the model parameter lengths
 match once the model is moved to the device. If the lengths do not match Lightning
@@ -37,9 +36,8 @@ Example:
         self.layer_1 = nn.Linear(32, 10, bias=False)
         self.layer_2 = nn.Linear(10, 32, bias=False)
         self.layer_3 = nn.Linear(32, 10, bias=False)
-        # TPU shared weights are copied independently
-        # on the XLA device and this line won't have any effect.
-        # However, it works fine for CPU and GPU.
+        # Lightning automatically ties these weights after moving them to the XLA device,
+        # so you only need the following line, just as on other accelerators.
         self.layer_3.weight = self.layer_1.weight
 
     def forward(self, x):
@@ -48,10 +46,6 @@ Example:
         x = self.layer_3(x)
         return x
 
-    def on_post_move_to_device(self):
-        # Weights shared after the model has been moved to TPU Device
-        self.layer_3.weight = self.layer_1.weight
-
 
 model = WeightSharingModule()
 trainer = Trainer(max_epochs=1, accelerator="tpu", devices=8)
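
For reference, the snippet below is a minimal, self-contained sketch of the example as it reads after this change. The `training_step`, `configure_optimizers`, the random `TensorDataset`, and the `lightning.pytorch` import path are illustrative assumptions added only to make the sketch runnable end to end (older releases expose the same classes via `pytorch_lightning`); they are not part of this diff.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Assumed import path; older releases use `pytorch_lightning` instead.
from lightning.pytorch import LightningModule, Trainer


class WeightSharingModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # Lightning re-ties these parameters after the module is moved to the XLA
        # device, so a plain assignment is enough, just as on CPU or GPU.
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        # Illustrative objective, not part of the documented example.
        (x,) = batch
        return self(x).pow(2).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # Illustrative random data so the sketch runs end to end on a TPU host.
    train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    model = WeightSharingModule()
    trainer = Trainer(max_epochs=1, accelerator="tpu", devices=8)
    trainer.fit(model, train_loader)
```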