Skip to content

Commit dd25263

Browse files
kaushikb11ninginthecloud
authored andcommitted
Add method to TPUSpawn plugin to override how models are setup (Lightning-AI#10039)
1 parent 9f08ef3 commit dd25263

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import torch
2222
import torch.multiprocessing as mp
23+
from torch.nn import Module
2324
from torch.utils.data import DataLoader
2425

2526
import pytorch_lightning as pl
@@ -118,6 +119,9 @@ def pre_dispatch(self):
118119
def setup(self) -> None:
119120
self.create_mp_queue()
120121

122+
def _setup_model(self, model: Module) -> Module:
123+
return model
124+
121125
def create_mp_queue(self):
122126
self.start_method = "fork"
123127
smp = mp.get_context(self.start_method)

0 commit comments

Comments
 (0)