Skip to content
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



- Deprecated `add_to_queue`, `get_from_queue` from Lightning Module ([#9126](https://github.com/PyTorchLightning/pytorch-lightning/pull/9126))

### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,9 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:

Args:
queue: the instance of the queue to append the data.

.. deprecated:: v1.5
This method was deprecated in v1.5 and will be removed in v1.7.
"""
callback_metrics: dict = apply_to_collection(
self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
Expand All @@ -1973,6 +1976,9 @@ def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:

Args:
queue: the instance of the queue from where to get the data.

.. deprecated:: v1.5
This method was deprecated in v1.5 and will be removed in v1.7.
"""
# NOTE: `add_to_queue` needs to be called before
callback_metrics: dict = queue.get()
Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,18 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()

# TODO(@daniellepintz): add trainer argument in v1.7
def post_dispatch(self):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()
# TODO(@daniellepintz): add `trainer.callback_metrics = self.mp_queue.get()` in v1.7

# get the `callback_metrics` and set it to the trainer
# only in case the user does not override it.

# TODO(@daniellepintz): remove in v1.7
self.lightning_module.get_from_queue(self.mp_queue)

# recover the weights of the processes trained in the children
Expand Down Expand Up @@ -286,6 +291,9 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
# TODO(@daniellepintz): add `self.mp_queue.put(trainer.callback_metrics)` in v1.7

# TODO(@daniellepintz): remove in v1.7
self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue

def __recover_child_process_weights(self, best_path, last_path):
Expand Down
20 changes: 19 additions & 1 deletion pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
Expand Down Expand Up @@ -43,6 +43,7 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
elif self.trainer.state.fn == TrainerFn.PREDICTING:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)
self._check_add_get_queue(model)

def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None:
# -----------------------------------
Expand Down Expand Up @@ -153,3 +154,20 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod
"The model taking a `dataloader_iter` argument in your `training_step` "
"is incompatible with `truncated_bptt_steps > 0`."
)

def _check_add_get_queue(self, model: "pl.LightningModule"):
r"""
Checks if add_to_queue or get_from_queue is overriden and sends a deprecation warning.

Args:
model: The lightning module

"""
if is_overridden("add_to_queue", model):
rank_zero_deprecation(
"The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7."
)
if is_overridden("get_from_queue", model):
rank_zero_deprecation(
"The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7."
)
11 changes: 11 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
from tests.helpers.runif import RunIf


def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir):
Expand Down Expand Up @@ -91,6 +92,16 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
_ = Trainer(prepare_data_per_node=False)


@RunIf(min_gpus=2)
def test_v1_7_0_deprecate_add_get_queue(tmpdir):
"""Tests if device is set correctly when training for DDPSpawnPlugin."""
with pytest.deprecated_call(match=r"`LightningModule.add_to_queue` method was deprecated in v1.5"):
_ = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")

with pytest.deprecated_call(match=r"`LightningModule.get_from_queue` method was deprecated in v1.5"):
_ = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")


@mock.patch("pytorch_lightning.loggers.test_tube.Experiment")
def test_v1_7_0_test_tube_logger(_, tmpdir):
with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
Expand Down
3 changes: 2 additions & 1 deletion tests/plugins/test_ddp_spawn_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_ddp_cpu():

@RunIf(min_gpus=2)
def test_ddp_spawn_extra_parameters(tmpdir):
"""Tests if device is set correctely when training for DDPSpawnPlugin."""
"""Tests if device is set correctly when training for DDPSpawnPlugin."""
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")

assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
Expand All @@ -76,4 +76,5 @@ def test_ddp_spawn_extra_parameters(tmpdir):

trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
# TODO(@daniellepintz) remove assert in v1.7
assert model.test_val == "test_val"