From 0de3c96fcbba339ad024daf4ad8f71bfe44640e0 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Mon, 11 Oct 2021 17:36:38 -0700 Subject: [PATCH 1/7] Directly call TrainingTypePlugin APIs instead of going through the Accelerator --- pytorch_lightning/accelerators/accelerator.py | 230 +++++++++++++++++- .../loops/dataloader/evaluation_loop.py | 2 +- .../loops/dataloader/prediction_loop.py | 2 +- pytorch_lightning/loops/fit_loop.py | 4 +- .../loops/optimization/manual_loop.py | 2 +- .../loops/optimization/optimizer_loop.py | 2 +- .../connectors/checkpoint_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 38 +-- tests/accelerators/test_cpu.py | 4 +- tests/deprecated_api/test_remove_1-6.py | 70 +++++- ..._ddp_fully_sharded_with_full_state_dict.py | 4 +- tests/plugins/test_ddp_plugin.py | 8 +- 12 files changed, 323 insertions(+), 47 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index cfed45e1db186..ab670ecbb80f4 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -59,15 +59,31 @@ def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: Trai self.optimizer_frequencies: List = [] def connect(self, model: "pl.LightningModule") -> None: - """Transfers ownership of the model to this plugin.""" + """Transfers ownership of the model to this plugin. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.connect`` directly. + """ + rank_zero_deprecation( + "`Accelerator.connect` is deprecated in v1.5 and will be removed in v1.6. " + "`connect` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.connect(model) def setup_environment(self) -> None: """Setup any processes or distributed connections. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.setup_environment`` directly. + This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete. """ + rank_zero_deprecation( + "`Accelerator.setup_environment` is deprecated in v1.5 and will be removed in v1.6. " + "`setup_environment` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.setup_environment() def setup(self, trainer: "pl.Trainer") -> None: @@ -82,12 +98,39 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() def start_training(self, trainer: "pl.Trainer") -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.start_training`` directly. + """ + rank_zero_deprecation( + "`Accelerator.start_training` is deprecated in v1.5 and will be removed in v1.6. " + "`start_training` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.start_training(trainer) def start_evaluating(self, trainer: "pl.Trainer") -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.start_evaluating`` directly. + """ + rank_zero_deprecation( + "`Accelerator.start_evaluating` is deprecated in v1.5 and will be removed in v1.6. " + "`start_evaluating` logic is implemented directly in the `TrainingTypePlugin` implementations." 
+ ) self.training_type_plugin.start_evaluating(trainer) def start_predicting(self, trainer: "pl.Trainer") -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.start_predicting`` directly. + """ + rank_zero_deprecation( + "`Accelerator.start_predicting` is deprecated in v1.5 and will be removed in v1.6. " + "`start_predicting` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.start_predicting(trainer) def pre_dispatch(self, trainer: "pl.Trainer") -> None: @@ -146,8 +189,16 @@ def root_device(self) -> torch.device: def teardown(self) -> None: """This method is called to teardown the training process. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.teardown`` directly. + It is the right place to release memory and free other resources. """ + rank_zero_deprecation( + "`Accelerator.teardown` is deprecated in v1.5 and will be removed in v1.6. " + "`teardown` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.teardown() def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: @@ -177,6 +228,15 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: return self.training_type_plugin.training_step(*step_kwargs.values()) def post_training_step(self) -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.post_training_step`` directly. + """ + rank_zero_deprecation( + "`Accelerator.post_training_step` is deprecated in v1.5 and will be removed in v1.6. " + "`post_training_step` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.post_training_step() def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -206,25 +266,49 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: """A hook to do something at the end of the training step. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.training_step_end`` directly. + Args: output: the output of the training step """ + rank_zero_deprecation( + "`Accelerator.training_step_end` is deprecated in v1.5 and will be removed in v1.6. " + "`training_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.training_step_end(output) def test_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: """A hook to do something at the end of the test step. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.test_step_end`` directly. + Args: output: the output of the test step """ + rank_zero_deprecation( + "`Accelerator.test_step_end` is deprecated in v1.5 and will be removed in v1.6. " + "`test_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.test_step_end(output) def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: """A hook to do something at the end of the validation step. + .. 
deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.validation_step_end`` directly. + Args: output: the output of the validation step """ + rank_zero_deprecation( + "`Accelerator.validation_step_end` is deprecated in v1.5 and will be removed in v1.6. " + "`validation_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.validation_step_end(output) def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: @@ -330,8 +414,16 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: """Returns state of model. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.lightning_module_state_dict`` directly. + Allows for syncing/collating model state from processes in custom plugins. """ + rank_zero_deprecation( + "`Accelerator.lightning_module_state_dict` is deprecated in v1.5 and will be removed in v1.6. " + "`lightning_module_state_dict` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.lightning_module_state_dict() def barrier(self, name: Optional[str] = None) -> None: @@ -342,7 +434,7 @@ def barrier(self, name: Optional[str] = None) -> None: """ rank_zero_deprecation( "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6. " - "Barrier logic is implemented directly in the `TrainingTypePlugin` implementations." + "`barrier` logic is implemented directly in the `TrainingTypePlugin` implementations." ) self.training_type_plugin.barrier(name=name) @@ -360,7 +452,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: """ rank_zero_deprecation( "`Accelerator.broadcast` is deprecated in v1.5 and will be removed in v1.6. " - "Broadcast logic is implemented directly in the `TrainingTypePlugin` implementations." + "`broadcast` logic is implemented directly in the `TrainingTypePlugin` implementations." ) return self.training_type_plugin.broadcast(obj, src) @@ -381,24 +473,40 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo """ rank_zero_deprecation( "`Accelerator.all_gather` is deprecated in v1.5 and will be removed in v1.6. " - "All-gather logic is implemented directly in the `TrainingTypePlugin` implementations." + "`all_gather` logic is implemented directly in the `TrainingTypePlugin` implementations." ) return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.process_dataloader`` directly. + Args: dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` """ + rank_zero_deprecation( + "`Accelerator.process_dataloader` is deprecated in v1.5 and will be removed in v1.6. " + "`process_dataloader` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.process_dataloader(dataloader) @property def results(self) -> Any: """The results of the last run will be cached within the training type plugin. + .. deprecated:: v1.5 + This property is deprecated in v1.5 and will be removed in v1.6. 
+ Please call ``training_type_plugin.results`` directly. + In distributed training, we make sure to transfer the results to the appropriate master process. """ + rank_zero_deprecation( + "`Accelerator.results` is deprecated in v1.5 and will be removed in v1.6. " + "Access results directly from the `TrainingTypePlugin`." + ) return self.training_type_plugin.results @contextlib.contextmanager @@ -417,10 +525,18 @@ def model_sharded_context(self) -> Generator[None, None, None]: def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.save_checkpoint`` directly. + Args: checkpoint: dict containing model and trainer state filepath: write-target file's path """ + rank_zero_deprecation( + "`Accelerator.save_checkpoint` is deprecated in v1.5 and will be removed in v1.6. " + "`save_checkpoint` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.save_checkpoint(checkpoint, filepath) @property @@ -429,9 +545,17 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: `TrainingTypePlugin` requires operating on the wrapped accelerator model. However this may break certain precision plugins such as APEX which require optimizers to be set. + .. deprecated:: v1.5 + This property is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.setup_optimizers_in_pre_dispatch`` directly. + Returns: If True, delay setup optimizers until `pre_dispatch`, else call within `setup`. """ + rank_zero_deprecation( + "`Accelerator.setup_optimizers_in_pre_dispatch` is deprecated in v1.5 and will be removed in v1.6. " + "Access `setup_optimizers_in_pre_dispatch` directly from the `TrainingTypePlugin`." + ) return self.training_type_plugin.setup_optimizers_in_pre_dispatch @property @@ -439,9 +563,17 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool: """Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading checkpoint. + .. deprecated:: v1.5 + This property is deprecated in v1.5 and will be removed in v1.6. + Please call ``training_type_plugin.restore_checkpoint_after_pre_dispatch`` directly. + Returns: If true, restore checkpoint after pre_dispatch. """ + rank_zero_deprecation( + "`Accelerator.restore_checkpoint_after_pre_dispatch` is deprecated in v1.5 and will be removed in v1.6. " + "Access `restore_checkpoint_after_pre_dispatch` directly from the `TrainingTypePlugin`." + ) return self.training_type_plugin.restore_checkpoint_after_pre_dispatch def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @@ -456,38 +588,110 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: raise NotImplementedError def on_train_start(self) -> None: - """Called when train begins.""" + """Called when train begins. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_train_start`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_train_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_train_start` logic is implemented directly in the `TrainingTypePlugin` implementations." 
+ ) return self.training_type_plugin.on_train_start() def on_validation_start(self) -> None: - """Called when validation begins.""" + """Called when validation begins. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_validation_start`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_validation_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_validation_start` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_validation_start() def on_test_start(self) -> None: - """Called when test begins.""" + """Called when test begins. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_test_start`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_test_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_test_start` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_test_start() def on_predict_start(self) -> None: - """Called when predict begins.""" + """Called when predict begins. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_predict_start`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_predict_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_predict_start` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_predict_start() def on_validation_end(self) -> None: - """Called when validation ends.""" + """Called when validation ends. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_validation_end`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_validation_end` is deprecated in v1.5 and will be removed in v1.6. " + "`on_validation_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_validation_end() def on_test_end(self) -> None: - """Called when test end.""" + """Called when test end. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_test_end`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_test_end` is deprecated in v1.5 and will be removed in v1.6. " + "`on_test_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_test_end() def on_predict_end(self) -> None: - """Called when predict ends.""" + """Called when predict ends. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_predict_end`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_predict_end` is deprecated in v1.5 and will be removed in v1.6. " + "`on_predict_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_predict_end() def on_train_end(self) -> None: - """Called when train ends.""" + """Called when train ends. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_train_end`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_train_end` is deprecated in v1.5 and will be removed in v1.6. 
" + "`on_train_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_train_end() # TODO: Update this in v1.7 (deprecation: #9816) def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - """Called in the training loop before anything happens for that batch.""" + """Called in the training loop before anything happens for that batch. + + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call + ``training_type_plugin.on_train_batch_start`` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_train_batch_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_train_batch_start` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_train_batch_start(batch, batch_idx) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 92c7d36cfd0d8..119e7f6c5472c 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -100,7 +100,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: void(*args, **kwargs) dataloader_idx: int = self.current_dataloader_idx - dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) + dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader) dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx) dl_max_batches = self._max_batches[dataloader_idx] diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index d4a6ab6d29cef..cf40316312107 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -84,7 +84,7 @@ def on_run_start(self) -> None: def advance(self, *args: Any, **kwargs: Any) -> None: """Predicts one entire dataloader.""" void(*args, **kwargs) - dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) + dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader) dataloader_iter = enumerate(dataloader) dl_max_batches = self.max_batches[self.current_dataloader_idx] diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 0d16c978fb374..6c692bb2dc6eb 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -204,7 +204,7 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" - dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader) data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader) with self.trainer.profiler.profile("run_training_epoch"): @@ -234,7 +234,7 @@ def on_run_end(self) -> None: self.trainer.call_hook("on_train_end") # give accelerators a chance to finish - self.trainer.accelerator.on_train_end() + self.trainer.training_type_plugin.on_train_end() def teardown(self) -> None: self.epoch_loop.teardown() diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 21ad8738ab5bb..4c8bf157d331b 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ 
b/pytorch_lightning/loops/optimization/manual_loop.py @@ -107,7 +107,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] lightning_module._current_fx_name = "training_step" with self.trainer.profiler.profile("training_step"): training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() + self.trainer.training_type_plugin.post_training_step() del step_kwargs diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 3a73795014c80..4ad85cc4650f2 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -450,7 +450,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos lightning_module._current_fx_name = "training_step" with self.trainer.profiler.profile("training_step"): training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() + self.trainer.training_type_plugin.post_training_step() del step_kwargs diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index d48ab3f518443..da6a81e8add44 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -465,7 +465,7 @@ def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: weights_only: saving model weights only """ _checkpoint = self.dump_checkpoint(weights_only) - self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) + self.trainer.training_type_plugin.save_checkpoint(_checkpoint, filepath) def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: metrics = ( @@ -478,7 +478,7 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: metric.persistent(True) metric.sync() - state_dict = self.trainer.accelerator.lightning_module_state_dict() + state_dict = self.trainer.training_type_plugin.lightning_module_state_dict() for metric in metrics: # sync can be a no-op (e.g. 
on cpu) so `unsync` would raise a user error exception if we don't check diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 01acae35fd46c..e2336c56e36d6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1020,24 +1020,24 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.callback_connector.attach_model_logging_functions(model) # attach model to the training type plugin - self.accelerator.connect(model) + self.training_type_plugin.connect(model) # hook self.data_connector.prepare_data() self.callback_connector._attach_model_callbacks() - if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch: + if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch: self._load_checkpoint_weights() # ---------------------------- # SET UP TRAINING # ---------------------------- self.call_hook("on_before_accelerator_backend_setup") - self.accelerator.setup_environment() + self.training_type_plugin.setup_environment() self._call_setup_hook() # allow user to setup lightning_module in accelerator environment # check if we should delay restoring checkpoint till later - if not self.accelerator.restore_checkpoint_after_pre_dispatch: + if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch: self.checkpoint_connector.resume_start() self._restore_modules_and_callbacks() @@ -1055,9 +1055,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, | || {self._dispatch} || | || LIGHTNING - {self.accelerator.start_training} || - or {self.accelerator.start_evaluating} || - or {self.accelerator.start_predicting} || FLOW + {self.training_type_plugin.start_training} || + or {self.training_type_plugin.start_evaluating} || + or {self.training_type_plugin.start_predicting} || FLOW | || {self.run_stage} || | || DIRECTION @@ -1087,7 +1087,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() - if self.accelerator.restore_checkpoint_after_pre_dispatch: + if self.training_type_plugin.restore_checkpoint_after_pre_dispatch: if self._ckpt_path: self._load_checkpoint_weights() @@ -1119,7 +1119,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.state.status = TrainerStatus.FINISHED self.state.stage = None - return self.accelerator.results + return self.training_type_plugin.results def _pre_dispatch(self): self.accelerator.pre_dispatch(self) @@ -1166,18 +1166,18 @@ def _post_dispatch(self): self.accelerator.post_dispatch(self) # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns # which need to happen before. 
- self.accelerator.teardown() + self.training_type_plugin.teardown() self.data_connector.teardown() self._active_loop.teardown() self.logger_connector.teardown() def _dispatch(self): if self.evaluating: - self.accelerator.start_evaluating(self) + self.training_type_plugin.start_evaluating(self) elif self.predicting: - self.accelerator.start_predicting(self) + self.training_type_plugin.start_predicting(self) else: - self.accelerator.start_training(self) + self.training_type_plugin.start_training(self) def run_stage(self): self.accelerator.dispatch(self) @@ -1509,22 +1509,26 @@ def precision_plugin(self) -> PrecisionPlugin: @property def global_rank(self) -> int: - return self.accelerator.training_type_plugin.global_rank + return self.training_type_plugin.global_rank @property def local_rank(self) -> int: # some training types define a local rank - return getattr(self.accelerator.training_type_plugin, "local_rank", 0) + return getattr(self.training_type_plugin, "local_rank", 0) @property def node_rank(self) -> int: # some training types define a local rank - return getattr(self.accelerator.training_type_plugin, "node_rank", 0) + return getattr(self.training_type_plugin, "node_rank", 0) @property def world_size(self) -> int: # some training types define a world size - return getattr(self.accelerator.training_type_plugin, "world_size", 1) + return getattr(self.training_type_plugin, "world_size", 1) + + @property + def should_rank_save_checkpoint(self) -> bool: + return self.training_type_plugin.should_rank_save_checkpoint @property def _distrib_type(self) -> DistributedType: diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index f95d182f9e5e1..697fae1644b1b 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -43,7 +43,7 @@ def test_restore_checkpoint_after_pre_dispatch_default(): """Assert default for restore_checkpoint_after_pre_dispatch is False.""" plugin = SingleDevicePlugin(torch.device("cpu")) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) - assert not accelerator.restore_checkpoint_after_pre_dispatch + assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch assert not plugin.restore_checkpoint_after_pre_dispatch @@ -77,7 +77,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO()) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) - assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch + assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch trainer = Trainer( diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 546b25a54aa83..6a54026724b68 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -330,7 +330,7 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import(): from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401 -def test_v1_6_0_deprecated_accelerator_collective(): +def test_v1_6_0_deprecated_accelerator_pass_through_functions(tmpdir): from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type import SingleDevicePlugin @@ -347,3 +347,71 @@ def 
test_v1_6_0_deprecated_accelerator_collective(): with pytest.deprecated_call(match="will be removed in v1.6"): tensor = torch.rand(2, 2, requires_grad=True) accelerator.all_gather(tensor) + + with pytest.deprecated_call(match="will be removed in v1.6"): + model = BoringModel() + accelerator.connect(model) + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.setup_environment() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.teardown() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.post_training_step() + + with pytest.deprecated_call(match="will be removed in v1.6"): + tensor = torch.rand(2, 2, requires_grad=True) + accelerator.training_step_end(tensor) + + with pytest.deprecated_call(match="will be removed in v1.6"): + tensor = torch.rand(2, 2, requires_grad=True) + accelerator.test_step_end(tensor) + + with pytest.deprecated_call(match="will be removed in v1.6"): + tensor = torch.rand(2, 2, requires_grad=True) + accelerator.validation_step_end(tensor) + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.lightning_module_state_dict() + + with pytest.deprecated_call(match="will be removed in v1.6"): + dl = model.train_dataloader() + accelerator.process_dataloader(dl) + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.results + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.setup_optimizers_in_pre_dispatch + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.restore_checkpoint_after_pre_dispatch + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_train_start() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_validation_start() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_test_start() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_predict_start() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_validation_end() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_test_end() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_predict_end() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_train_end() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_train_batch_start(batch=None, batch_idx=0) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 61688b8847778..6aeb79f80b255 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -24,7 +24,7 @@ def test_invalid_on_cpu(tmpdir): ): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins="fsdp") assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) - trainer.accelerator.setup_environment() + trainer.training_type_plugin.setup_environment() @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @@ -120,7 +120,7 @@ def test_fully_sharded_plugin_checkpoint_multi_gpus(tmpdir): def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): # Use FullySharded to get the state dict for the sake of comparison - model_state_dict = trainer.accelerator.lightning_module_state_dict() + model_state_dict = 
trainer.training_type_plugin.lightning_module_state_dict() if trainer.is_global_zero: saved_model = cls.load_from_checkpoint(ckpt_path) diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index 03cc0e1ff7beb..618e191dc8ee6 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -108,8 +108,8 @@ def test_ddp_configure_ddp(): ) # test wrap the model if fitting trainer.state.fn = TrainerFn.FITTING - trainer.accelerator.connect(model) - trainer.accelerator.setup_environment() + trainer.training_type_plugin.connect(model) + trainer.training_type_plugin.setup_environment() trainer.accelerator.setup(trainer) trainer.lightning_module.trainer = trainer assert isinstance(trainer.model, LightningModule) @@ -122,8 +122,8 @@ def test_ddp_configure_ddp(): plugins=[ddp_plugin], ) # test do not wrap the model if trainerFN is not fitting - trainer.accelerator.connect(model) - trainer.accelerator.setup_environment() + trainer.training_type_plugin.connect(model) + trainer.training_type_plugin.setup_environment() trainer.accelerator.setup(trainer) trainer.lightning_module.trainer = trainer trainer._pre_dispatch() From 16238dda872e4ac876a5c005116bb6ebbe556562 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Mon, 11 Oct 2021 17:43:13 -0700 Subject: [PATCH 2/7] Directly call TrainingTypePlugin APIs instead of going through the Accelerator --- pytorch_lightning/accelerators/accelerator.py | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index ab670ecbb80f4..f41d52bf11f4e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -62,7 +62,7 @@ def connect(self, model: "pl.LightningModule") -> None: """Transfers ownership of the model to this plugin. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.connect`` directly. + `training_type_plugin.connect` directly. """ rank_zero_deprecation( "`Accelerator.connect` is deprecated in v1.5 and will be removed in v1.6. " @@ -75,7 +75,7 @@ def setup_environment(self) -> None: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.setup_environment`` directly. + Please call `training_type_plugin.setup_environment` directly. This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete. @@ -101,7 +101,7 @@ def start_training(self, trainer: "pl.Trainer") -> None: """ .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.start_training`` directly. + Please call `training_type_plugin.start_training` directly. """ rank_zero_deprecation( "`Accelerator.start_training` is deprecated in v1.5 and will be removed in v1.6. " @@ -113,7 +113,7 @@ def start_evaluating(self, trainer: "pl.Trainer") -> None: """ .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.start_evaluating`` directly. + Please call `training_type_plugin.start_evaluating` directly. """ rank_zero_deprecation( "`Accelerator.start_evaluating` is deprecated in v1.5 and will be removed in v1.6. " @@ -125,7 +125,7 @@ def start_predicting(self, trainer: "pl.Trainer") -> None: """ .. 
deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.start_predicting`` directly. + Please call `training_type_plugin.start_predicting` directly. """ rank_zero_deprecation( "`Accelerator.start_predicting` is deprecated in v1.5 and will be removed in v1.6. " @@ -191,7 +191,7 @@ def teardown(self) -> None: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.teardown`` directly. + Please call `training_type_plugin.teardown` directly. It is the right place to release memory and free other resources. """ @@ -231,7 +231,7 @@ def post_training_step(self) -> None: """ .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.post_training_step`` directly. + Please call `training_type_plugin.post_training_step` directly. """ rank_zero_deprecation( "`Accelerator.post_training_step` is deprecated in v1.5 and will be removed in v1.6. " @@ -268,7 +268,7 @@ def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.training_step_end`` directly. + Please call `training_type_plugin.training_step_end` directly. Args: output: the output of the training step @@ -284,7 +284,7 @@ def test_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.test_step_end`` directly. + Please call `training_type_plugin.test_step_end` directly. Args: output: the output of the test step @@ -300,7 +300,7 @@ def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OU .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.validation_step_end`` directly. + Please call `training_type_plugin.validation_step_end` directly. Args: output: the output of the validation step @@ -416,7 +416,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.lightning_module_state_dict`` directly. + Please call `training_type_plugin.lightning_module_state_dict` directly. Allows for syncing/collating model state from processes in custom plugins. """ @@ -430,7 +430,7 @@ def barrier(self, name: Optional[str] = None) -> None: """ .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.barrier`` directly. + Please call `training_type_plugin.barrier` directly. """ rank_zero_deprecation( "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6. " @@ -444,7 +444,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.broadcast`` directly. + Please call `training_type_plugin.broadcast` directly. Args: obj: Object to broadcast to all process, usually a tensor or collection of tensors. @@ -461,7 +461,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.all_gather`` directly. 
+ Please call `training_type_plugin.all_gather` directly. Args: tensor: tensor of shape (batch, ...) @@ -482,7 +482,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.process_dataloader`` directly. + Please call `training_type_plugin.process_dataloader` directly. Args: dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` @@ -499,7 +499,7 @@ def results(self) -> Any: .. deprecated:: v1.5 This property is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.results`` directly. + Please call `training_type_plugin.results` directly. In distributed training, we make sure to transfer the results to the appropriate master process. """ @@ -527,7 +527,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.save_checkpoint`` directly. + Please call `training_type_plugin.save_checkpoint` directly. Args: checkpoint: dict containing model and trainer state @@ -547,7 +547,7 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: .. deprecated:: v1.5 This property is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.setup_optimizers_in_pre_dispatch`` directly. + Please call `training_type_plugin.setup_optimizers_in_pre_dispatch` directly. Returns: If True, delay setup optimizers until `pre_dispatch`, else call within `setup`. @@ -565,7 +565,7 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool: .. deprecated:: v1.5 This property is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.restore_checkpoint_after_pre_dispatch`` directly. + Please call `training_type_plugin.restore_checkpoint_after_pre_dispatch` directly. Returns: If true, restore checkpoint after pre_dispatch. @@ -591,7 +591,7 @@ def on_train_start(self) -> None: """Called when train begins. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_train_start`` directly. + `training_type_plugin.on_train_start` directly. """ rank_zero_deprecation( "`Accelerator.on_train_start` is deprecated in v1.5 and will be removed in v1.6. " @@ -603,7 +603,7 @@ def on_validation_start(self) -> None: """Called when validation begins. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_validation_start`` directly. + `training_type_plugin.on_validation_start` directly. """ rank_zero_deprecation( "`Accelerator.on_validation_start` is deprecated in v1.5 and will be removed in v1.6. " @@ -615,7 +615,7 @@ def on_test_start(self) -> None: """Called when test begins. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_test_start`` directly. + `training_type_plugin.on_test_start` directly. """ rank_zero_deprecation( "`Accelerator.on_test_start` is deprecated in v1.5 and will be removed in v1.6. " @@ -627,7 +627,7 @@ def on_predict_start(self) -> None: """Called when predict begins. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_predict_start`` directly. + `training_type_plugin.on_predict_start` directly. 
""" rank_zero_deprecation( "`Accelerator.on_predict_start` is deprecated in v1.5 and will be removed in v1.6. " @@ -639,7 +639,7 @@ def on_validation_end(self) -> None: """Called when validation ends. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_validation_end`` directly. + `training_type_plugin.on_validation_end` directly. """ rank_zero_deprecation( "`Accelerator.on_validation_end` is deprecated in v1.5 and will be removed in v1.6. " @@ -651,7 +651,7 @@ def on_test_end(self) -> None: """Called when test end. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_test_end`` directly. + `training_type_plugin.on_test_end` directly. """ rank_zero_deprecation( "`Accelerator.on_test_end` is deprecated in v1.5 and will be removed in v1.6. " @@ -663,7 +663,7 @@ def on_predict_end(self) -> None: """Called when predict ends. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_predict_end`` directly. + `training_type_plugin.on_predict_end` directly. """ rank_zero_deprecation( "`Accelerator.on_predict_end` is deprecated in v1.5 and will be removed in v1.6. " @@ -675,7 +675,7 @@ def on_train_end(self) -> None: """Called when train ends. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_train_end`` directly. + `training_type_plugin.on_train_end` directly. """ rank_zero_deprecation( "`Accelerator.on_train_end` is deprecated in v1.5 and will be removed in v1.6. " @@ -688,7 +688,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = """Called in the training loop before anything happens for that batch. .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - ``training_type_plugin.on_train_batch_start`` directly. + `training_type_plugin.on_train_batch_start` directly. """ rank_zero_deprecation( "`Accelerator.on_train_batch_start` is deprecated in v1.5 and will be removed in v1.6. " From e6dba23a90832f2c785b19e3079bc446d918a776 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 12 Oct 2021 15:58:37 +0100 Subject: [PATCH 3/7] add changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f12dda513629..5c5036b693349 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896)) +- Changed `training_type_plugin` in favor of `accelerator` within the loops ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901)) + + ### Deprecated - Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175)) From c98c8b30a8221adf9b5e5a38996f86817183625f Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Tue, 12 Oct 2021 09:43:39 -0700 Subject: [PATCH 4/7] fix format --- pytorch_lightning/accelerators/accelerator.py | 84 ++++++++++--------- pytorch_lightning/trainer/trainer.py | 4 +- tests/deprecated_api/test_remove_1-6.py | 8 +- ..._ddp_fully_sharded_with_full_state_dict.py | 2 +- tests/plugins/test_ddp_plugin.py | 4 +- 5 files changed, 49 insertions(+), 53 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index f41d52bf11f4e..1dc37349d0707 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -61,6 +61,8 @@ def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: Trai def connect(self, model: "pl.LightningModule") -> None: """Transfers ownership of the model to this plugin. + See deprecation warning below. + .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call `training_type_plugin.connect` directly. """ @@ -73,17 +75,9 @@ def connect(self, model: "pl.LightningModule") -> None: def setup_environment(self) -> None: """Setup any processes or distributed connections. - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.setup_environment` directly. - This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete. """ - rank_zero_deprecation( - "`Accelerator.setup_environment` is deprecated in v1.5 and will be removed in v1.6. " - "`setup_environment` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) self.training_type_plugin.setup_environment() def setup(self, trainer: "pl.Trainer") -> None: @@ -189,16 +183,8 @@ def root_device(self) -> torch.device: def teardown(self) -> None: """This method is called to teardown the training process. - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.teardown` directly. - It is the right place to release memory and free other resources. """ - rank_zero_deprecation( - "`Accelerator.teardown` is deprecated in v1.5 and will be removed in v1.6. " - "`teardown` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) self.training_type_plugin.teardown() def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: @@ -588,22 +574,17 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: raise NotImplementedError def on_train_start(self) -> None: - """Called when train begins. - - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_train_start` directly. - """ - rank_zero_deprecation( - "`Accelerator.on_train_start` is deprecated in v1.5 and will be removed in v1.6. 
" - "`on_train_start` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) + """Called when train begins.""" return self.training_type_plugin.on_train_start() def on_validation_start(self) -> None: """Called when validation begins. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_validation_start` directly. + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_validation_start` directly. """ rank_zero_deprecation( "`Accelerator.on_validation_start` is deprecated in v1.5 and will be removed in v1.6. " @@ -614,8 +595,11 @@ def on_validation_start(self) -> None: def on_test_start(self) -> None: """Called when test begins. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_test_start` directly. + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_test_start` directly. """ rank_zero_deprecation( "`Accelerator.on_test_start` is deprecated in v1.5 and will be removed in v1.6. " @@ -626,8 +610,11 @@ def on_test_start(self) -> None: def on_predict_start(self) -> None: """Called when predict begins. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_predict_start` directly. + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_predict_start` directly. """ rank_zero_deprecation( "`Accelerator.on_predict_start` is deprecated in v1.5 and will be removed in v1.6. " @@ -638,8 +625,11 @@ def on_predict_start(self) -> None: def on_validation_end(self) -> None: """Called when validation ends. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_validation_end` directly. + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_validation_end` directly. """ rank_zero_deprecation( "`Accelerator.on_validation_end` is deprecated in v1.5 and will be removed in v1.6. " @@ -650,8 +640,11 @@ def on_validation_end(self) -> None: def on_test_end(self) -> None: """Called when test end. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_test_end` directly. + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_test_end` directly. """ rank_zero_deprecation( "`Accelerator.on_test_end` is deprecated in v1.5 and will be removed in v1.6. " @@ -662,8 +655,11 @@ def on_test_end(self) -> None: def on_predict_end(self) -> None: """Called when predict ends. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_predict_end` directly. + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_predict_end` directly. """ rank_zero_deprecation( "`Accelerator.on_predict_end` is deprecated in v1.5 and will be removed in v1.6. 
" @@ -674,8 +670,11 @@ def on_predict_end(self) -> None: def on_train_end(self) -> None: """Called when train ends. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_train_end` directly. + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_train_end` directly. """ rank_zero_deprecation( "`Accelerator.on_train_end` is deprecated in v1.5 and will be removed in v1.6. " @@ -687,8 +686,11 @@ def on_train_end(self) -> None: def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Called in the training loop before anything happens for that batch. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.on_train_batch_start` directly. + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_train_batch_start` directly. """ rank_zero_deprecation( "`Accelerator.on_train_batch_start` is deprecated in v1.5 and will be removed in v1.6. " diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e2336c56e36d6..610f512324b82 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1033,7 +1033,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # SET UP TRAINING # ---------------------------- self.call_hook("on_before_accelerator_backend_setup") - self.training_type_plugin.setup_environment() + self.accelerator.setup_environment() self._call_setup_hook() # allow user to setup lightning_module in accelerator environment # check if we should delay restoring checkpoint till later @@ -1166,7 +1166,7 @@ def _post_dispatch(self): self.accelerator.post_dispatch(self) # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns # which need to happen before. 
- self.training_type_plugin.teardown() + self.accelerator.teardown() self.data_connector.teardown() self._active_loop.teardown() self.logger_connector.teardown() diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 6a54026724b68..d8b05aff5fe43 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -330,7 +330,7 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import(): from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401 -def test_v1_6_0_deprecated_accelerator_pass_through_functions(tmpdir): +def test_v1_6_0_deprecated_accelerator_pass_through_functions(): from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type import SingleDevicePlugin @@ -352,9 +352,6 @@ def test_v1_6_0_deprecated_accelerator_pass_through_functions(tmpdir): model = BoringModel() accelerator.connect(model) - with pytest.deprecated_call(match="will be removed in v1.6"): - accelerator.setup_environment() - with pytest.deprecated_call(match="will be removed in v1.6"): accelerator.teardown() @@ -389,9 +386,6 @@ def test_v1_6_0_deprecated_accelerator_pass_through_functions(tmpdir): with pytest.deprecated_call(match="will be removed in v1.6"): accelerator.restore_checkpoint_after_pre_dispatch - with pytest.deprecated_call(match="will be removed in v1.6"): - accelerator.on_train_start() - with pytest.deprecated_call(match="will be removed in v1.6"): accelerator.on_validation_start() diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 6aeb79f80b255..332c4e8d69d60 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -24,7 +24,7 @@ def test_invalid_on_cpu(tmpdir): ): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins="fsdp") assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) - trainer.training_type_plugin.setup_environment() + trainer.accelerator.setup_environment() @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index 618e191dc8ee6..6c75a12e76bfc 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -109,7 +109,7 @@ def test_ddp_configure_ddp(): # test wrap the model if fitting trainer.state.fn = TrainerFn.FITTING trainer.training_type_plugin.connect(model) - trainer.training_type_plugin.setup_environment() + trainer.accelerator.setup_environment() trainer.accelerator.setup(trainer) trainer.lightning_module.trainer = trainer assert isinstance(trainer.model, LightningModule) @@ -123,7 +123,7 @@ def test_ddp_configure_ddp(): ) # test do not wrap the model if trainerFN is not fitting trainer.training_type_plugin.connect(model) - trainer.training_type_plugin.setup_environment() + trainer.accelerator.setup_environment() trainer.accelerator.setup(trainer) trainer.lightning_module.trainer = trainer trainer._pre_dispatch() From 8c91d45772620808796d477d141d2f66200c9098 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Tue, 12 Oct 2021 12:54:17 -0700 Subject: [PATCH 5/7] fix comment format --- tests/deprecated_api/test_remove_1-6.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index d8b05aff5fe43..a5101d3311bf3 
100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -352,9 +352,6 @@ def test_v1_6_0_deprecated_accelerator_pass_through_functions(): model = BoringModel() accelerator.connect(model) - with pytest.deprecated_call(match="will be removed in v1.6"): - accelerator.teardown() - with pytest.deprecated_call(match="will be removed in v1.6"): accelerator.post_training_step() From 26823edc591c900cab51987300fd4b6f1379cd6f Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Tue, 12 Oct 2021 13:29:28 -0700 Subject: [PATCH 6/7] fix comment format --- pytorch_lightning/accelerators/accelerator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 1dc37349d0707..541cf5de3be2b 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -63,8 +63,9 @@ def connect(self, model: "pl.LightningModule") -> None: See deprecation warning below. - .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call - `training_type_plugin.connect` directly. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.connect` directly. """ rank_zero_deprecation( "`Accelerator.connect` is deprecated in v1.5 and will be removed in v1.6. " From 08a8785adf853563a9fdf5926a024673d5da02fb Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Wed, 13 Oct 2021 16:36:48 -0700 Subject: [PATCH 7/7] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c5036b693349..269f8e5f27a4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -288,7 +288,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896)) -- Changed `training_type_plugin` in favor of `accelerator` within the loops ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901)) +- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901)) ### Deprecated
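
Reviewer's note (illustrative sketch, not part of the patches above): the pattern this series applies throughout is "deprecate the `Accelerator` pass-through, forward to the owning `TrainingTypePlugin`, and migrate callers to the plugin directly." The self-contained Python sketch below models that pattern with `process_dataloader` as the example; the `TrainingTypePlugin` and `Accelerator` classes and the standalone `rank_zero_deprecation` stub are simplified stand-ins assumed for illustration, not the actual Lightning implementations.

    import warnings


    def rank_zero_deprecation(message: str) -> None:
        # Simplified stand-in: the real Lightning helper emits the warning
        # on global rank zero only, so multi-process runs warn just once.
        warnings.warn(message, DeprecationWarning)


    class TrainingTypePlugin:
        """Owns the logic the Accelerator used to proxy."""

        def process_dataloader(self, dataloader):
            # Real plugins may wrap or re-shard the dataloader here.
            return dataloader


    class Accelerator:
        def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
            self.training_type_plugin = training_type_plugin

        def process_dataloader(self, dataloader):
            # Deprecation shim kept for one release cycle: warn, then forward.
            rank_zero_deprecation(
                "`Accelerator.process_dataloader` is deprecated in v1.5 and will be "
                "removed in v1.6. `process_dataloader` logic is implemented directly "
                "in the `TrainingTypePlugin` implementations."
            )
            return self.training_type_plugin.process_dataloader(dataloader)


    plugin = TrainingTypePlugin()
    accelerator = Accelerator(plugin)

    # Old call path (what the loops used before this series): still works, but warns.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        accelerator.process_dataloader([1, 2, 3])
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # New call path (what the loops use after this series): no warning.
    plugin.process_dataloader([1, 2, 3])

The design choice worth noting: the shim stays in place for one minor release so downstream code calling `trainer.accelerator.process_dataloader(...)` keeps working with a warning, while the loops, connectors, and tests in this series move to `trainer.training_type_plugin` ahead of the v1.6 removal.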