diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fab74f8b2135..b92d91a377574 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -162,6 +162,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +- Deprecated `add_to_queue`, `get_from_queue` from Lightning Module ([#9126](https://github.com/PyTorchLightning/pytorch-lightning/pull/9126)) + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 096333388c3b1..1c0d27c4c8ce1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1960,6 +1960,9 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: Args: queue: the instance of the queue to append the data. + + .. deprecated:: v1.5 + This method was deprecated in v1.5 and will be removed in v1.7. """ callback_metrics: dict = apply_to_collection( self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() @@ -1973,6 +1976,9 @@ def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: Args: queue: the instance of the queue from where to get the data. + + .. deprecated:: v1.5 + This method was deprecated in v1.5 and will be removed in v1.7. """ # NOTE: `add_to_queue` needs to be called before callback_metrics: dict = queue.get() diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index c31a908902a27..6ba53ce168dff 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -213,13 +213,18 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() + # TODO(@daniellepintz): add trainer argument in v1.7 def post_dispatch(self): # restore main state with best weights best_path = self.mp_queue.get() last_path = self.mp_queue.get() self._results = self.mp_queue.get() + # TODO(@daniellepintz): add `trainer.callback_metrics = self.mp_queue.get()` in v1.7 + # get the `callback_metrics` and set it to the trainer # only in case the user does not override it. + + # TODO(@daniellepintz): remove in v1.7 self.lightning_module.get_from_queue(self.mp_queue) # recover the weights of the processes trained in the children @@ -286,6 +291,9 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul self.mp_queue.put(best_model_path) self.mp_queue.put(last_path) self.mp_queue.put(results) + # TODO(@daniellepintz): add `self.mp_queue.put(trainer.callback_metrics)` in v1.7 + + # TODO(@daniellepintz): remove in v1.7 self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue def __recover_child_process_weights(self, best_path, last_path): diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index d9c341c5dfaeb..217e489109bd7 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -13,7 +13,7 @@ # limitations under the License. import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -43,6 +43,7 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) + self._check_add_get_queue(model) def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- @@ -153,3 +154,20 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod "The model taking a `dataloader_iter` argument in your `training_step` " "is incompatible with `truncated_bptt_steps > 0`." ) + + def _check_add_get_queue(self, model: "pl.LightningModule"): + r""" + Checks if add_to_queue or get_from_queue is overriden and sends a deprecation warning. + + Args: + model: The lightning module + + """ + if is_overridden("add_to_queue", model): + rank_zero_deprecation( + "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7." + ) + if is_overridden("get_from_queue", model): + rank_zero_deprecation( + "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7." + ) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index ae8f9e1dcc53d..25be14252d11b 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -21,6 +21,7 @@ from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule +from tests.helpers.runif import RunIf def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir): @@ -91,6 +92,16 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir): _ = Trainer(prepare_data_per_node=False) +@RunIf(min_gpus=2) +def test_v1_7_0_deprecate_add_get_queue(tmpdir): + """Tests if device is set correctly when training for DDPSpawnPlugin.""" + with pytest.deprecated_call(match=r"`LightningModule.add_to_queue` method was deprecated in v1.5"): + _ = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn") + + with pytest.deprecated_call(match=r"`LightningModule.get_from_queue` method was deprecated in v1.5"): + _ = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn") + + @mock.patch("pytorch_lightning.loggers.test_tube.Experiment") def test_v1_7_0_test_tube_logger(_, tmpdir): with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"): diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 1ab94446c8176..e9a8f1aaf839e 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -62,7 +62,7 @@ def test_ddp_cpu(): @RunIf(min_gpus=2) def test_ddp_spawn_extra_parameters(tmpdir): - """Tests if device is set correctely when training for DDPSpawnPlugin.""" + """Tests if device is set correctly when training for DDPSpawnPlugin.""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn") assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) @@ -76,4 +76,5 @@ def test_ddp_spawn_extra_parameters(tmpdir): trainer.fit(model, datamodule=dm) assert trainer.callback_metrics[val_name] == torch.tensor(val) + # TODO(@daniellepintz) remove assert in v1.7 assert model.test_val == "test_val"