
Commit d916973

Refactor setup_training and remove test_mode (#5388)
* ref and fix call for on_pretrained_routine
* avoid failing tests
* unnecessary_call
* unnecessary call in accelerators
* tmpdir
* rm test_mode
* pep
* updates
* more ref
* Revert "more ref" (this reverts commit 5d9e95f)
* more refac

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 83b1ff4 commit d916973

24 files changed (+139, -191 lines)
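Taken together, the diffs below hoist the per-accelerator train() boilerplate into the Accelerator base class: each accelerator used to call train_loop.setup_training(model) itself before dispatching, and now the shared train() calls the new trainer.setup_trainer(model) instead, while what remains of setup_training runs only on the training path inside train_or_test(). A minimal sketch of the resulting control flow, with the Trainer internals simplified for illustration (the real classes carry many more responsibilities; only the calls visible in the diffs are mirrored here):

# Minimal sketch of the refactored flow, not the actual Lightning classes.
class Accelerator:
    def __init__(self, trainer):
        self.trainer = trainer

    def train(self):
        # shared by CPU/GPU/DP accelerators after this commit
        self.trainer.setup_trainer(self.trainer.model)
        return self.train_or_test()

    def train_or_test(self):
        if self.trainer.testing:
            return self.trainer.run_test()
        # the pretrain routine now runs only when actually training
        self.trainer.train_loop.setup_training()
        return self.trainer.train()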

pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 0 deletions
@@ -52,6 +52,10 @@ def __init__(self,
     def setup(self, model):
         pass
 
+    def train(self):
+        self.trainer.setup_trainer(self.trainer.model)
+        return self.train_or_test()
+
     def teardown(self):
         # Ensure if necessary all processes are finished
         self.barrier()

@@ -66,6 +70,7 @@ def train_or_test(self):
         if self.trainer.testing:
             results = self.trainer.run_test()
         else:
+            self.trainer.train_loop.setup_training()
             results = self.trainer.train()
         return results
 

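With train() hoisted into the base class above, the per-subclass copies deleted in the files below become redundant: a concrete accelerator only needs to implement setup(). A hypothetical minimal subclass (not from the diff), just to show what is left to override:

# Hypothetical subclass: it inherits train() and train_or_test() from
# the Accelerator sketch above and only attaches the model in setup().
class MinimalAccelerator(Accelerator):
    def setup(self, model):
        self.trainer.model = model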
pytorch_lightning/accelerators/cpu_accelerator.py

Lines changed: 0 additions & 10 deletions
@@ -50,16 +50,6 @@ def setup(self, model):
 
         self.trainer.model = model
 
-    def train(self):
-        model = self.trainer.model
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-        return results
-
     def _step(self, model_step: Callable, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():

pytorch_lightning/accelerators/ddp2_accelerator.py

Lines changed: 1 addition & 5 deletions
@@ -186,9 +186,6 @@ def ddp_train(self, process_idx, mp_queue, model):
 
         self.ddp_plugin.on_after_setup_optimizers(self.trainer)
 
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -198,8 +195,7 @@ def ddp_train(self, process_idx, mp_queue, model):
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)
 
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)
 
         # train or test
         results = self.train_or_test()

pytorch_lightning/accelerators/ddp_accelerator.py

Lines changed: 1 addition & 5 deletions
@@ -285,9 +285,6 @@ def ddp_train(self, process_idx, model):
         # allow for lr schedulers as well
         self.setup_optimizers(model)
 
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -297,9 +294,8 @@ def ddp_train(self, process_idx, model):
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)
 
-        # set up training routine
         self.barrier('ddp_setup')
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)
 
         # train or test
         results = self.train_or_test()

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py

Lines changed: 1 addition & 5 deletions
@@ -146,9 +146,6 @@ def ddp_train(self, process_idx, mp_queue, model):
 
         self.ddp_plugin.on_after_setup_optimizers(self.trainer)
 
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -158,8 +155,7 @@ def ddp_train(self, process_idx, mp_queue, model):
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)
 
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)
 
         # train or test
         results = self.train_or_test()

pytorch_lightning/accelerators/ddp_hpc_accelerator.py

Lines changed: 1 addition & 5 deletions
@@ -177,9 +177,6 @@ def ddp_train(self, process_idx, model):
 
         self.ddp_plugin.on_after_setup_optimizers(self.trainer)
 
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -189,8 +186,7 @@ def ddp_train(self, process_idx, model):
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)
 
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)
 
         # train or test
         results = self.train_or_test()

pytorch_lightning/accelerators/ddp_spawn_accelerator.py

Lines changed: 1 addition & 5 deletions
@@ -161,9 +161,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
 
         self.ddp_plugin.on_after_setup_optimizers(self.trainer)
 
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
         # 16-bit
         model = self.trainer.precision_connector.connect(model)

@@ -173,8 +170,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # allow user to configure ddp
         model = self.configure_ddp(model, device_ids)
 
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
+        self.trainer.setup_trainer(model)
 
         # train or test
         results = self.train_or_test()

pytorch_lightning/accelerators/dp_accelerator.py

Lines changed: 0 additions & 10 deletions
@@ -101,16 +101,6 @@ def __init_nvidia_apex(self, model):
 
         return model
 
-    def train(self):
-        model = self.trainer.model
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-
-        return results
-
     def teardown(self):
         # replace the original fwd function
         self.trainer.model.forward = self.model_autocast_original_forward

pytorch_lightning/accelerators/gpu_accelerator.py

Lines changed: 0 additions & 10 deletions
@@ -56,16 +56,6 @@ def setup(self, model):
 
         self.trainer.model = model
 
-    def train(self):
-        model = self.trainer.model
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-        return results
-
     def _step(self, model_step: Callable, args):
         args[0] = self.to_device(args[0])
 

pytorch_lightning/accelerators/horovod_accelerator.py

Lines changed: 1 addition & 2 deletions
@@ -104,8 +104,7 @@ def train(self):
                 # Synchronization will be performed explicitly following backward()
                 stack.enter_context(optimizer.skip_synchronize())
 
-            # set up training routine
-            self.trainer.train_loop.setup_training(self.trainer.model)
+            self.trainer.setup_trainer(self.trainer.model)
 
             # train or test
            results = self.train_or_test()

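Note that the Horovod accelerator keeps its own train() override: it must enter every optimizer's skip_synchronize() context before calling setup_trainer and train_or_test. A self-contained sketch of that ExitStack pattern, using a hypothetical stand-in optimizer since the Horovod objects aren't shown in this diff:

from contextlib import ExitStack, contextmanager

class StubOptimizer:
    # hypothetical stand-in for a Horovod DistributedOptimizer
    @contextmanager
    def skip_synchronize(self):
        print("gradient sync deferred")
        yield
        print("gradient sync re-enabled")

optimizers = [StubOptimizer(), StubOptimizer()]
with ExitStack() as stack:
    # enter all contexts now; ExitStack unwinds them together when the
    # block ends, mirroring the loop at the top of HorovodAccelerator.train()
    for opt in optimizers:
        stack.enter_context(opt.skip_synchronize())
    # ... setup_trainer(...) and train_or_test() would run here ...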