
Commit b7c2e0a

Trainer only references accelerator (#6039)
* Trainer only references accelerator where it can
* Move teardown to the trainer, as it is responsible for the accelerator
1 parent 7189d67 commit b7c2e0a
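
The change narrows what the Trainer touches: the Trainer talks only to the Accelerator, and the Accelerator forwards each call to the training-type and precision plugins it owns. A minimal, self-contained sketch of that layering (toy classes written for illustration, not the real pytorch_lightning implementations):

# Toy schematic of the delegation chain this commit enforces: Trainer -> Accelerator -> plugins.
class ToyTrainingTypePlugin:
    def pre_dispatch(self):
        print("training-type plugin: pre-dispatch setup")

    def start_training(self, trainer):
        print("training-type plugin: entering the training loop")


class ToyPrecisionPlugin:
    def pre_dispatch(self):
        print("precision plugin: pre-dispatch setup")


class ToyAccelerator:
    """Owns the plugins and only forwards calls to them."""

    def __init__(self, training_type_plugin, precision_plugin):
        self.training_type_plugin = training_type_plugin
        self.precision_plugin = precision_plugin

    def pre_dispatch(self):
        self.training_type_plugin.pre_dispatch()
        self.precision_plugin.pre_dispatch()

    def start_training(self, trainer):
        self.training_type_plugin.start_training(trainer)


class ToyTrainer:
    """After this commit the Trainer references only the accelerator, never the plugins directly."""

    def __init__(self, accelerator):
        self.accelerator_backend = accelerator

    def fit(self):
        self.accelerator_backend.pre_dispatch()
        self.accelerator_backend.start_training(self)


ToyTrainer(ToyAccelerator(ToyTrainingTypePlugin(), ToyPrecisionPlugin())).fit()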

File tree: 4 files changed (+46 lines, -33 lines)

pytorch_lightning/accelerators/accelerator.py

Lines changed: 32 additions & 17 deletions
@@ -76,6 +76,25 @@ def setup(self, trainer: "Trainer", model: LightningModule) -> None:
         self.setup_optimizers(trainer)
         self.connect_precision_plugin(self.precision_plugin)

+    def start_training(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_training(trainer)
+
+    def start_testing(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_testing(trainer)
+
+    def start_predicting(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_predicting(trainer)
+
+    def pre_dispatch(self) -> None:
+        """Hook to do something before the training/evaluation/prediction starts."""
+        self.training_type_plugin.pre_dispatch()
+        self.precision_plugin.pre_dispatch()
+
+    def post_dispatch(self) -> None:
+        """Hook to do something before the training/evaluation/prediction starts."""
+        self.training_type_plugin.post_dispatch()
+        self.precision_plugin.post_dispatch()
+
     @property
     def model(self) -> torch.nn.Module:
         """Returns the model. This can also be a wrapped LightningModule.

@@ -224,23 +243,6 @@ def validation_step_end(self, output):
         """
         return self.training_type_plugin.validation_step_end(output)

-    def predict(self, args):
-        """The prediction step.
-
-        Args:
-            args: the arguments for the models predict step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): Integer displaying index of this batch
-                optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                hiddens(:class:`~torch.Tensor`): Passed in if
-                    :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
-
-        """
-        batch = self.to_device(args[0])
-        args[0] = batch
-        return self.training_type_plugin.predict(*args)
-
     def backward(
         self,
         closure_loss: torch.Tensor,

@@ -380,6 +382,10 @@ def on_save(self, checkpoint):
     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)

+    def broadcast(self, obj: object, src: int = 0) -> object:
+        """Broadcasts an object to all processes"""
+        return self.training_type_plugin.broadcast(obj, src)
+
     def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
         """
         Function to gather a tensor from several distributed processes

@@ -399,3 +405,12 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
             dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
         """
         return self.training_type_plugin.process_dataloader(dataloader)
+
+    @property
+    def results(self) -> Any:
+        """
+        The results of the last training/testing run will be cached here.
+        In distributed training, we make sure to transfer the results to the appropriate master process.
+        """
+        # TODO: improve these docs
+        return self.training_type_plugin.results
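
Each of the methods added above is a thin pass-through to a plugin. A hedged way to confirm that forwarding in isolation, assuming the constructor signature Accelerator(precision_plugin=..., training_type_plugin=...) from this era of pytorch_lightning (adjust the kwargs if your version differs), is to wire the Accelerator to mock plugins:

# Illustrative forwarding check with mock plugins; the constructor kwargs are an assumption
# about the pytorch_lightning ~1.2 API and may need adjusting.
from unittest.mock import MagicMock

from pytorch_lightning.accelerators import Accelerator

precision, training_type = MagicMock(), MagicMock()
accelerator = Accelerator(precision_plugin=precision, training_type_plugin=training_type)

# pre_dispatch should reach both plugins.
accelerator.pre_dispatch()
training_type.pre_dispatch.assert_called_once()
precision.pre_dispatch.assert_called_once()

# start_training and broadcast should reach only the training-type plugin.
trainer = MagicMock()  # stands in for a real Trainer
accelerator.start_training(trainer)
training_type.start_training.assert_called_once_with(trainer)

accelerator.broadcast("logs/run_0")
training_type.broadcast.assert_called_once_with("logs/run_0", 0)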

pytorch_lightning/trainer/data_loading.py

Lines changed: 1 addition & 1 deletion
@@ -399,7 +399,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
         dataloader = self._flatten_dl_only(dataloader)

         if self.accelerator_backend is not None:
-            self.training_type_plugin.barrier('get_dataloaders')
+            self.accelerator_backend.barrier('get_dataloaders')
         return dataloader

     def _flatten_dl_only(self, dataloaders):

pytorch_lightning/trainer/properties.py

Lines changed: 3 additions & 3 deletions
@@ -21,14 +21,14 @@
 from torch.optim import Optimizer

 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState

@@ -138,7 +138,7 @@ def log_dir(self) -> Optional[str]:
         else:
             dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

-        dirpath = self.training_type_plugin.broadcast(dirpath)
+        dirpath = self.accelerator_backend.broadcast(dirpath)
         return dirpath

     @property

@@ -365,7 +365,7 @@ def lightning_optimizers(self) -> List[LightningOptimizer]:

     @property
     def lightning_module(self) -> LightningModule:
-        return self.training_type_plugin.lightning_module
+        return self.accelerator_backend.lightning_module

     @property
     def optimizers(self) -> Optional[List[Optimizer]]:

pytorch_lightning/trainer/trainer.py

Lines changed: 10 additions & 12 deletions
@@ -22,7 +22,6 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule

@@ -33,6 +32,7 @@
 from pytorch_lightning.profiler import BaseProfiler
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector

@@ -484,7 +484,7 @@ def fit(
         # trainer.dispatch || LIGHTNING
         #   | ||
         # start_training or start_testing or start_predicting call || FLOW
-        #   from `accelerator.training_type_plugin` ||
+        #   from `accelerator` ||
         #   | || DIRECTION
         # run_train or run_test or run_predict call ||
         #   from `trainer` ||

@@ -532,26 +532,24 @@ def fit(

         self._set_running_stage(None, model)

-        return self.training_type_plugin.results or 1
+        return self.accelerator_backend.results or 1

     def pre_dispatch(self):
-        self.training_type_plugin.pre_dispatch()
-        self.precision_plugin.pre_dispatch()
+        self.accelerator_backend.pre_dispatch()

     def post_dispatch(self):
-        self.training_type_plugin.post_dispatch()
-        self.precision_plugin.post_dispatch()
+        self.accelerator_backend.post_dispatch()
         self.accelerator_backend.teardown()

     def dispatch(self):
         if self.testing:
-            self.training_type_plugin.start_testing(self)
+            self.accelerator_backend.start_testing(self)

         elif self.predicting:
-            self.training_type_plugin.start_predicting(self)
+            self.accelerator_backend.start_predicting(self)

         else:
-            self.training_type_plugin.start_training(self)
+            self.accelerator_backend.start_training(self)

     def train_or_test_or_predict(self):
         if self.testing:

@@ -575,7 +573,7 @@ def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):

     def _pre_training_routine(self):
         # wait for all to join if on distributed
-        self.accelerator.training_type_plugin.barrier("setup_training")
+        self.accelerator.barrier("setup_training")

         # register auto-resubmit when on SLURM
         self.slurm_connector.register_slurm_signal_handlers()

@@ -948,7 +946,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
             )
             return {}
         if not self._device_type == DeviceType.TPU:
-            self.training_type_plugin.barrier()
+            self.accelerator_backend.barrier()

         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
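
Taken together with the accelerator.py changes, the second half of the commit message ("move teardown to the trainer") fixes the call order the Trainer drives around a run. A schematic of that order (illustrative pseudocode mirroring the diff above, not the actual Trainer source):

# Lifecycle the Trainer now orchestrates purely through accelerator_backend (schematic only).
def run(trainer):
    trainer.accelerator_backend.pre_dispatch()    # forwards to training-type and precision plugins

    # dispatch(): pick exactly one entry point on the accelerator
    if trainer.testing:
        trainer.accelerator_backend.start_testing(trainer)
    elif trainer.predicting:
        trainer.accelerator_backend.start_predicting(trainer)
    else:
        trainer.accelerator_backend.start_training(trainer)

    trainer.accelerator_backend.post_dispatch()   # forwards to training-type and precision plugins
    trainer.accelerator_backend.teardown()        # teardown is owned by the Trainer, which owns the accelerator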
