5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -52,6 +52,10 @@ def setup(self, model):
def setup(self, model):
pass

def train(self):
self.trainer.setup_trainer(self.trainer.model)
return self.train_or_test()

def teardown(self):
# Ensure if necessary all processes are finished
self.barrier()
@@ -66,6 +70,7 @@ def train_or_test(self):
if self.trainer.testing:
results = self.trainer.run_test()
else:
self.trainer.train_loop.setup_training()
results = self.trainer.train()
return results

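Note: the base Accelerator now owns the shared entry point, so the per-accelerator train() duplicates removed below become unnecessary. A minimal sketch of the consolidated flow, assuming setup_trainer() binds the model to the trainer and setup_training() reads it back from that state (neither body is shown in this diff):

class Accelerator:

    def __init__(self, trainer):
        self.trainer = trainer

    def train(self):
        # Single shared entry point: attach the model to the trainer state,
        # then dispatch to either the test path or the fit path.
        self.trainer.setup_trainer(self.trainer.model)
        return self.train_or_test()

    def train_or_test(self):
        if self.trainer.testing:
            results = self.trainer.run_test()
        else:
            # setup_training() no longer receives the model as an argument;
            # it relies on the state populated by setup_trainer().
            self.trainer.train_loop.setup_training()
            results = self.trainer.train()
        return results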
10 changes: 0 additions & 10 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -50,16 +50,6 @@ def setup(self, model):

self.trainer.model = model

def train(self):
model = self.trainer.model

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
return results

def _step(self, model_step: Callable, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -186,9 +186,6 @@ def ddp_train(self, process_idx, mp_queue, model):

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -198,8 +195,7 @@ def ddp_train(self, process_idx, mp_queue, model):
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -285,9 +285,6 @@ def ddp_train(self, process_idx, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -297,9 +294,8 @@ def ddp_train(self, process_idx, model):
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.barrier('ddp_setup')
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
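Every DDP variant also drops the explicit copy_trainer_model_properties(model) call that used to run before the model was wrapped, which implies setup_trainer(model) now covers that step. A hypothetical sketch of what the consolidated call is assumed to do; the real body of setup_trainer is not part of this diff:

def setup_trainer(self, model):
    # Assumed: copy trainer-level properties onto the raw LightningModule
    # before it is wrapped by DistributedDataParallel.
    self.model_connector.copy_trainer_model_properties(model)
    # Assumed: keep a reference on the trainer so later calls such as
    # train_loop.setup_training() can reach the model without an argument.
    self.model = model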
@@ -146,9 +146,6 @@ def ddp_train(self, process_idx, mp_queue, model):

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -158,8 +155,7 @@ def ddp_train(self, process_idx, mp_queue, model):
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -177,9 +177,6 @@ def ddp_train(self, process_idx, model):

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -189,8 +186,7 @@ def ddp_train(self, process_idx, model):
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -161,9 +161,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -173,8 +170,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
10 changes: 0 additions & 10 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -101,16 +101,6 @@ def __init_nvidia_apex(self, model):

return model

def train(self):
model = self.trainer.model
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()

return results

def teardown(self):
# replace the original fwd function
self.trainer.model.forward = self.model_autocast_original_forward
10 changes: 0 additions & 10 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -56,16 +56,6 @@ def setup(self, model):

self.trainer.model = model

def train(self):
model = self.trainer.model

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
return results

def _step(self, model_step: Callable, args):
args[0] = self.to_device(args[0])

3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -104,8 +104,7 @@ def train(self):
# Synchronization will be performed explicitly following backward()
stack.enter_context(optimizer.skip_synchronize())

# set up training routine
self.trainer.train_loop.setup_training(self.trainer.model)
self.trainer.setup_trainer(self.trainer.model)

# train or test
results = self.train_or_test()
3 changes: 1 addition & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -134,8 +134,7 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
# setup TPU training
self.__setup_tpu_training(model, trainer)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
@@ -13,8 +13,8 @@
# limitations under the License.

import os
from pathlib import Path
import re
from pathlib import Path
from typing import Optional, Union

import torch
@@ -44,7 +44,7 @@ def __init__(self, trainer):
# used to validate checkpointing logic
self.has_trained = False

def restore_weights(self, model: LightningModule) -> None:
def restore_weights(self) -> None:
"""
Attempt to restore a checkpoint (e.g. weights) in this priority:
1. from HPC weights
@@ -64,7 +64,7 @@ def restore_weights(self, model: LightningModule) -> None:
rank_zero_info(f'restored hpc model from: {checkpoint_path}')

# 2. Attempt to restore states from `resume_from_checkpoint` file
elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
elif self.trainer.resume_from_checkpoint is not None:
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

# wait for all to catch up
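Two things change in restore_weights(): the model parameter is gone (the connector reads what it needs from self.trainer), and the not self.trainer.testing guard is dropped, so a resume_from_checkpoint path is now restored for test runs as well. A minimal sketch of the resulting priority order; _attempt_hpc_restore() is a hypothetical helper standing in for the HPC branch that is cut off in this view:

def restore_weights(self) -> None:
    # 1. HPC weights take priority (that branch is truncated in this diff).
    if self._attempt_hpc_restore():  # hypothetical helper, not in the diff
        return
    # 2. resume_from_checkpoint is honoured regardless of trainer.testing.
    if self.trainer.resume_from_checkpoint is not None:
        self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)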
@@ -208,9 +208,9 @@ def add_progress_bar_metrics(self, metrics):

self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode):
def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result):
self._track_callback_metrics(deprecated_eval_results, using_eval_result)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results)

def evaluation_epoch_end(self, testing):
# reset dataloader idx
@@ -239,7 +239,7 @@ def prepare_eval_loop_results(self):
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self, test_mode):
def get_evaluate_epoch_results(self):
if not self.trainer.running_sanity_check:
# log all the metrics as a single dict
metrics_to_log = self.cached_results.get_epoch_log_metrics()
@@ -249,7 +249,7 @@ def get_evaluate_epoch_results(self, test_mode):
self.prepare_eval_loop_results()

# log results of test
if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test:
print('-' * 80)
for result_idx, results in enumerate(self.eval_loop_results):
print(f'DATALOADER:{result_idx} TEST RESULTS')
@@ -330,7 +330,7 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
if len(dataloader_result_metrics) > 0:
self.eval_loop_results.append(dataloader_result_metrics)

def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mode):
def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
if self.trainer.running_sanity_check:
return

@@ -350,7 +350,7 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
callback_metrics = result.callback_metrics

# in testing we don't need the callback metrics
if test_mode:
if self.trainer.testing:
callback_metrics = {}
else:
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
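The logger-connector changes all follow one pattern: stop threading a test_mode flag through the call chain and consult self.trainer.testing at the point of use. A small self-contained sketch of that pattern; the stub class and its attribute values are placeholders, not code from the repository:

class _TrainerStub:
    # Placeholder for the trainer state the connector reads.
    testing = True
    is_global_zero = True
    verbose_test = True

class LoggerConnector:

    def __init__(self, trainer):
        self.trainer = trainer

    def get_evaluate_epoch_results(self):
        # No test_mode parameter any more; the mode comes from trainer state.
        if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test:
            print('-' * 80)

LoggerConnector(_TrainerStub()).get_evaluate_epoch_results()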