diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f12dda513629..269f8e5f27a4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896)) +- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901)) + + ### Deprecated - Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index cfed45e1db186..541cf5de3be2b 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -59,7 +59,18 @@ def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: Trai self.optimizer_frequencies: List = [] def connect(self, model: "pl.LightningModule") -> None: - """Transfers ownership of the model to this plugin.""" + """Transfers ownership of the model to this plugin. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.connect` directly. + """ + rank_zero_deprecation( + "`Accelerator.connect` is deprecated in v1.5 and will be removed in v1.6. " + "`connect` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.connect(model) def setup_environment(self) -> None: @@ -82,12 +93,39 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() def start_training(self, trainer: "pl.Trainer") -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.start_training` directly. + """ + rank_zero_deprecation( + "`Accelerator.start_training` is deprecated in v1.5 and will be removed in v1.6. " + "`start_training` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.start_training(trainer) def start_evaluating(self, trainer: "pl.Trainer") -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.start_evaluating` directly. + """ + rank_zero_deprecation( + "`Accelerator.start_evaluating` is deprecated in v1.5 and will be removed in v1.6. " + "`start_evaluating` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.start_evaluating(trainer) def start_predicting(self, trainer: "pl.Trainer") -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.start_predicting` directly. + """ + rank_zero_deprecation( + "`Accelerator.start_predicting` is deprecated in v1.5 and will be removed in v1.6. " + "`start_predicting` logic is implemented directly in the `TrainingTypePlugin` implementations."
+ ) self.training_type_plugin.start_predicting(trainer) def pre_dispatch(self, trainer: "pl.Trainer") -> None: @@ -177,6 +215,15 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: return self.training_type_plugin.training_step(*step_kwargs.values()) def post_training_step(self) -> None: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.post_training_step` directly. + """ + rank_zero_deprecation( + "`Accelerator.post_training_step` is deprecated in v1.5 and will be removed in v1.6. " + "`post_training_step` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.post_training_step() def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -206,25 +253,49 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: """A hook to do something at the end of the training step. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.training_step_end` directly. + Args: output: the output of the training step """ + rank_zero_deprecation( + "`Accelerator.training_step_end` is deprecated in v1.5 and will be removed in v1.6. " + "`training_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.training_step_end(output) def test_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: """A hook to do something at the end of the test step. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.test_step_end` directly. + Args: output: the output of the test step """ + rank_zero_deprecation( + "`Accelerator.test_step_end` is deprecated in v1.5 and will be removed in v1.6. " + "`test_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.test_step_end(output) def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: """A hook to do something at the end of the validation step. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.validation_step_end` directly. + Args: output: the output of the validation step """ + rank_zero_deprecation( + "`Accelerator.validation_step_end` is deprecated in v1.5 and will be removed in v1.6. " + "`validation_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.validation_step_end(output) def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: @@ -330,19 +401,27 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: """Returns state of model. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.lightning_module_state_dict` directly. + Allows for syncing/collating model state from processes in custom plugins. """ + rank_zero_deprecation( + "`Accelerator.lightning_module_state_dict` is deprecated in v1.5 and will be removed in v1.6. " + "`lightning_module_state_dict` logic is implemented directly in the `TrainingTypePlugin` implementations." 
+ ) return self.training_type_plugin.lightning_module_state_dict() def barrier(self, name: Optional[str] = None) -> None: """ .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.barrier`` directly. + Please call `training_type_plugin.barrier` directly. """ rank_zero_deprecation( "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6. " - "Barrier logic is implemented directly in the `TrainingTypePlugin` implementations." + "`barrier` logic is implemented directly in the `TrainingTypePlugin` implementations." ) self.training_type_plugin.barrier(name=name) @@ -352,7 +431,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.broadcast`` directly. + Please call `training_type_plugin.broadcast` directly. Args: obj: Object to broadcast to all process, usually a tensor or collection of tensors. @@ -360,7 +439,7 @@ """ rank_zero_deprecation( "`Accelerator.broadcast` is deprecated in v1.5 and will be removed in v1.6. " - "Broadcast logic is implemented directly in the `TrainingTypePlugin` implementations." + "`broadcast` logic is implemented directly in the `TrainingTypePlugin` implementations." ) return self.training_type_plugin.broadcast(obj, src) @@ -369,7 +448,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo .. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. - Please call ``training_type_plugin.all_gather`` directly. + Please call `training_type_plugin.all_gather` directly. Args: tensor: tensor of shape (batch, ...) @@ -381,24 +460,40 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo """ rank_zero_deprecation( "`Accelerator.all_gather` is deprecated in v1.5 and will be removed in v1.6. " - "All-gather logic is implemented directly in the `TrainingTypePlugin` implementations." + "`all_gather` logic is implemented directly in the `TrainingTypePlugin` implementations." ) return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.process_dataloader` directly. + Args: dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` """ + rank_zero_deprecation( + "`Accelerator.process_dataloader` is deprecated in v1.5 and will be removed in v1.6. " + "`process_dataloader` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.process_dataloader(dataloader) @property def results(self) -> Any: """The results of the last run will be cached within the training type plugin. + .. deprecated:: v1.5 + This property is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.results` directly. + In distributed training, we make sure to transfer the results to the appropriate master process. """ + rank_zero_deprecation( + "`Accelerator.results` is deprecated in v1.5 and will be removed in v1.6. " + "Access `results` directly from the `TrainingTypePlugin`."
+ ) return self.training_type_plugin.results @contextlib.contextmanager @@ -417,10 +512,18 @@ def model_sharded_context(self) -> Generator[None, None, None]: def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.save_checkpoint` directly. + Args: checkpoint: dict containing model and trainer state filepath: write-target file's path """ + rank_zero_deprecation( + "`Accelerator.save_checkpoint` is deprecated in v1.5 and will be removed in v1.6. " + "`save_checkpoint` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) self.training_type_plugin.save_checkpoint(checkpoint, filepath) @property @@ -429,9 +532,17 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: `TrainingTypePlugin` requires operating on the wrapped accelerator model. However this may break certain precision plugins such as APEX which require optimizers to be set. + .. deprecated:: v1.5 + This property is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.setup_optimizers_in_pre_dispatch` directly. + Returns: If True, delay setup optimizers until `pre_dispatch`, else call within `setup`. """ + rank_zero_deprecation( + "`Accelerator.setup_optimizers_in_pre_dispatch` is deprecated in v1.5 and will be removed in v1.6. " + "Access `setup_optimizers_in_pre_dispatch` directly from the `TrainingTypePlugin`." + ) return self.training_type_plugin.setup_optimizers_in_pre_dispatch @property @@ -439,9 +550,17 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool: """Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading checkpoint. + .. deprecated:: v1.5 + This property is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.restore_checkpoint_after_pre_dispatch` directly. + Returns: If true, restore checkpoint after pre_dispatch. """ + rank_zero_deprecation( + "`Accelerator.restore_checkpoint_after_pre_dispatch` is deprecated in v1.5 and will be removed in v1.6. " + "Access `restore_checkpoint_after_pre_dispatch` directly from the `TrainingTypePlugin`." + ) return self.training_type_plugin.restore_checkpoint_after_pre_dispatch def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @@ -460,34 +579,122 @@ def on_train_start(self) -> None: return self.training_type_plugin.on_train_start() def on_validation_start(self) -> None: - """Called when validation begins.""" + """Called when validation begins. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_validation_start` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_validation_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_validation_start` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_validation_start() def on_test_start(self) -> None: - """Called when test begins.""" + """Called when test begins. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_test_start` directly.
+ """ + rank_zero_deprecation( + "`Accelerator.on_test_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_test_start` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_test_start() def on_predict_start(self) -> None: - """Called when predict begins.""" + """Called when predict begins. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_predict_start` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_predict_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_predict_start` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_predict_start() def on_validation_end(self) -> None: - """Called when validation ends.""" + """Called when validation ends. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_validation_end` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_validation_end` is deprecated in v1.5 and will be removed in v1.6. " + "`on_validation_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_validation_end() def on_test_end(self) -> None: - """Called when test end.""" + """Called when test end. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_test_end` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_test_end` is deprecated in v1.5 and will be removed in v1.6. " + "`on_test_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_test_end() def on_predict_end(self) -> None: - """Called when predict ends.""" + """Called when predict ends. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_predict_end` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_predict_end` is deprecated in v1.5 and will be removed in v1.6. " + "`on_predict_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_predict_end() def on_train_end(self) -> None: - """Called when train ends.""" + """Called when train ends. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_train_end` directly. + """ + rank_zero_deprecation( + "`Accelerator.on_train_end` is deprecated in v1.5 and will be removed in v1.6. " + "`on_train_end` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_train_end() # TODO: Update this in v1.7 (deprecation: #9816) def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - """Called in the training loop before anything happens for that batch.""" + """Called in the training loop before anything happens for that batch. + + See deprecation warning below. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.6. + Please call `training_type_plugin.on_train_batch_start` directly. 
+ """ + rank_zero_deprecation( + "`Accelerator.on_train_batch_start` is deprecated in v1.5 and will be removed in v1.6. " + "`on_train_batch_start` logic is implemented directly in the `TrainingTypePlugin` implementations." + ) return self.training_type_plugin.on_train_batch_start(batch, batch_idx) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 92c7d36cfd0d8..119e7f6c5472c 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -100,7 +100,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: void(*args, **kwargs) dataloader_idx: int = self.current_dataloader_idx - dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) + dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader) dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx) dl_max_batches = self._max_batches[dataloader_idx] diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index d4a6ab6d29cef..cf40316312107 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -84,7 +84,7 @@ def on_run_start(self) -> None: def advance(self, *args: Any, **kwargs: Any) -> None: """Predicts one entire dataloader.""" void(*args, **kwargs) - dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) + dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader) dataloader_iter = enumerate(dataloader) dl_max_batches = self.max_batches[self.current_dataloader_idx] diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 0d16c978fb374..6c692bb2dc6eb 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -204,7 +204,7 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" - dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) + dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader) data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader) with self.trainer.profiler.profile("run_training_epoch"): @@ -234,7 +234,7 @@ def on_run_end(self) -> None: self.trainer.call_hook("on_train_end") # give accelerators a chance to finish - self.trainer.accelerator.on_train_end() + self.trainer.training_type_plugin.on_train_end() def teardown(self) -> None: self.epoch_loop.teardown() diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 21ad8738ab5bb..4c8bf157d331b 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -107,7 +107,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] lightning_module._current_fx_name = "training_step" with self.trainer.profiler.profile("training_step"): training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() + self.trainer.training_type_plugin.post_training_step() del step_kwargs diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 
3a73795014c80..4ad85cc4650f2 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -450,7 +450,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos lightning_module._current_fx_name = "training_step" with self.trainer.profiler.profile("training_step"): training_step_output = self.trainer.accelerator.training_step(step_kwargs) - self.trainer.accelerator.post_training_step() + self.trainer.training_type_plugin.post_training_step() del step_kwargs diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index d48ab3f518443..da6a81e8add44 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -465,7 +465,7 @@ def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: weights_only: saving model weights only """ _checkpoint = self.dump_checkpoint(weights_only) - self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) + self.trainer.training_type_plugin.save_checkpoint(_checkpoint, filepath) def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: metrics = ( @@ -478,7 +478,7 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: metric.persistent(True) metric.sync() - state_dict = self.trainer.accelerator.lightning_module_state_dict() + state_dict = self.trainer.training_type_plugin.lightning_module_state_dict() for metric in metrics: # sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 01acae35fd46c..610f512324b82 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1020,13 +1020,13 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.callback_connector.attach_model_logging_functions(model) # attach model to the training type plugin - self.accelerator.connect(model) + self.training_type_plugin.connect(model) # hook self.data_connector.prepare_data() self.callback_connector._attach_model_callbacks() - if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch: + if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch: self._load_checkpoint_weights() # ---------------------------- @@ -1037,7 +1037,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self._call_setup_hook() # allow user to setup lightning_module in accelerator environment # check if we should delay restoring checkpoint till later - if not self.accelerator.restore_checkpoint_after_pre_dispatch: + if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch: self.checkpoint_connector.resume_start() self._restore_modules_and_callbacks() @@ -1055,9 +1055,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, | || {self._dispatch} || | || LIGHTNING - {self.accelerator.start_training} || - or {self.accelerator.start_evaluating} || - or {self.accelerator.start_predicting} || FLOW + {self.training_type_plugin.start_training} || + or {self.training_type_plugin.start_evaluating} || + or {self.training_type_plugin.start_predicting} || FLOW | || {self.run_stage} || | || DIRECTION @@ -1087,7 +1087,7 @@ def _run(self, model: "pl.LightningModule") -> 
Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() - if self.accelerator.restore_checkpoint_after_pre_dispatch: + if self.training_type_plugin.restore_checkpoint_after_pre_dispatch: if self._ckpt_path: self._load_checkpoint_weights() @@ -1119,7 +1119,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.state.status = TrainerStatus.FINISHED self.state.stage = None - return self.accelerator.results + return self.training_type_plugin.results def _pre_dispatch(self): self.accelerator.pre_dispatch(self) @@ -1173,11 +1173,11 @@ def _post_dispatch(self): def _dispatch(self): if self.evaluating: - self.accelerator.start_evaluating(self) + self.training_type_plugin.start_evaluating(self) elif self.predicting: - self.accelerator.start_predicting(self) + self.training_type_plugin.start_predicting(self) else: - self.accelerator.start_training(self) + self.training_type_plugin.start_training(self) def run_stage(self): self.accelerator.dispatch(self) @@ -1509,22 +1509,26 @@ def precision_plugin(self) -> PrecisionPlugin: @property def global_rank(self) -> int: - return self.accelerator.training_type_plugin.global_rank + return self.training_type_plugin.global_rank @property def local_rank(self) -> int: # some training types define a local rank - return getattr(self.accelerator.training_type_plugin, "local_rank", 0) + return getattr(self.training_type_plugin, "local_rank", 0) @property def node_rank(self) -> int: # some training types define a local rank - return getattr(self.accelerator.training_type_plugin, "node_rank", 0) + return getattr(self.training_type_plugin, "node_rank", 0) @property def world_size(self) -> int: # some training types define a world size - return getattr(self.accelerator.training_type_plugin, "world_size", 1) + return getattr(self.training_type_plugin, "world_size", 1) + + @property + def should_rank_save_checkpoint(self) -> bool: + return self.training_type_plugin.should_rank_save_checkpoint @property def _distrib_type(self) -> DistributedType: diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index f95d182f9e5e1..697fae1644b1b 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -43,7 +43,7 @@ def test_restore_checkpoint_after_pre_dispatch_default(): """Assert default for restore_checkpoint_after_pre_dispatch is False.""" plugin = SingleDevicePlugin(torch.device("cpu")) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) - assert not accelerator.restore_checkpoint_after_pre_dispatch + assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch assert not plugin.restore_checkpoint_after_pre_dispatch @@ -77,7 +77,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO()) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) - assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch + assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch trainer = Trainer( diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 546b25a54aa83..a5101d3311bf3 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ 
b/tests/deprecated_api/test_remove_1-6.py @@ -330,7 +330,7 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import(): from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401 -def test_v1_6_0_deprecated_accelerator_collective(): +def test_v1_6_0_deprecated_accelerator_pass_through_functions(): from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type import SingleDevicePlugin @@ -347,3 +347,62 @@ def test_v1_6_0_deprecated_accelerator_collective(): with pytest.deprecated_call(match="will be removed in v1.6"): tensor = torch.rand(2, 2, requires_grad=True) accelerator.all_gather(tensor) + + with pytest.deprecated_call(match="will be removed in v1.6"): + model = BoringModel() + accelerator.connect(model) + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.post_training_step() + + with pytest.deprecated_call(match="will be removed in v1.6"): + tensor = torch.rand(2, 2, requires_grad=True) + accelerator.training_step_end(tensor) + + with pytest.deprecated_call(match="will be removed in v1.6"): + tensor = torch.rand(2, 2, requires_grad=True) + accelerator.test_step_end(tensor) + + with pytest.deprecated_call(match="will be removed in v1.6"): + tensor = torch.rand(2, 2, requires_grad=True) + accelerator.validation_step_end(tensor) + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.lightning_module_state_dict() + + with pytest.deprecated_call(match="will be removed in v1.6"): + dl = model.train_dataloader() + accelerator.process_dataloader(dl) + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.results + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.setup_optimizers_in_pre_dispatch + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.restore_checkpoint_after_pre_dispatch + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_validation_start() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_test_start() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_predict_start() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_validation_end() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_test_end() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_predict_end() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_train_end() + + with pytest.deprecated_call(match="will be removed in v1.6"): + accelerator.on_train_batch_start(batch=None, batch_idx=0) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 61688b8847778..332c4e8d69d60 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -120,7 +120,7 @@ def test_fully_sharded_plugin_checkpoint_multi_gpus(tmpdir): def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): # Use FullySharded to get the state dict for the sake of comparison - model_state_dict = trainer.accelerator.lightning_module_state_dict() + model_state_dict = trainer.training_type_plugin.lightning_module_state_dict() if trainer.is_global_zero: saved_model = cls.load_from_checkpoint(ckpt_path) diff --git a/tests/plugins/test_ddp_plugin.py 
b/tests/plugins/test_ddp_plugin.py index 03cc0e1ff7beb..6c75a12e76bfc 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -108,7 +108,7 @@ def test_ddp_configure_ddp(): ) # test wrap the model if fitting trainer.state.fn = TrainerFn.FITTING - trainer.accelerator.connect(model) + trainer.training_type_plugin.connect(model) trainer.accelerator.setup_environment() trainer.accelerator.setup(trainer) trainer.lightning_module.trainer = trainer @@ -122,7 +122,7 @@ def test_ddp_configure_ddp(): plugins=[ddp_plugin], ) # test do not wrap the model if trainerFN is not fitting - trainer.accelerator.connect(model) + trainer.training_type_plugin.connect(model) trainer.accelerator.setup_environment() trainer.accelerator.setup(trainer) trainer.lightning_module.trainer = trainer
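
Every hunk above follows the same pattern: the `Accelerator` method or property becomes a deprecated pass-through that warns via `rank_zero_deprecation` and forwards to the `TrainingTypePlugin`, while internal call sites in the loops, connectors, trainer, and tests switch to `trainer.training_type_plugin`. Downstream code migrates with the same one-to-one rename; the sketch below is hypothetical user code (assuming PyTorch Lightning 1.5 with the default single-device CPU setup), not part of this diff.

# Migration sketch (hypothetical user code, assuming PyTorch Lightning 1.5): the deprecated
# `Accelerator` pass-throughs still work but emit a rank-zero deprecation warning, so callers
# should swap `trainer.accelerator.<name>` for `trainer.training_type_plugin.<name>`.
import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)

# Before (warns as of v1.5):
#   flag = trainer.accelerator.restore_checkpoint_after_pre_dispatch
#   results = trainer.accelerator.results

# After (calls the plugin that actually owns the implementation):
flag = trainer.training_type_plugin.restore_checkpoint_after_pre_dispatch
results = trainer.training_type_plugin.results
print(flag, results)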