Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 17 additions & 28 deletions src/lightning/pytorch/strategies/launchers/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import logging
import os
import tempfile
from collections import UserList
from contextlib import suppress
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
Expand Down Expand Up @@ -170,7 +169,7 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train
trainer.state = worker_output.trainer_state

# get the `callback_metrics` and set it to the trainer
self.get_from_queue(trainer, worker_output.extra)
self.update_main_process_results(trainer, worker_output.extra)

def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
rank_zero_debug("Collecting results from rank 0 process.")
Expand All @@ -194,9 +193,8 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)

# adds the `callback_metrics` to the queue
extra = _FakeQueue()
self.add_to_queue(trainer, extra)
# add extra result data from trainer to send to main process
extra = self.get_extra_results(trainer)

return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)

Expand All @@ -210,29 +208,33 @@ def _check_torchdistx_support(self) -> None:
f" initialization when `start_method='spawn'`."
)

def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.
def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
"""Gather extra state from the Trainer and return it as a dictionary for sending back to the main process.
To avoid issues with memory sharing, we cast the data to numpy.

Args:
trainer: reference to the Trainer.
queue: the instance of the queue to append the data.

Returns:
A dictionary with items to send back to the main process where :meth:`update_main_process_results` will
process this output.
"""
callback_metrics: dict = apply_to_collection(
trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
) # send as numpy to avoid issues with memory sharing
queue.put(callback_metrics)
return {"callback_metrics": callback_metrics}

def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.

Args:
trainer: reference to the Trainer.
queue: the instance of the queue from where to get the data.
extra: A dictionary with trainer state that was sent from the worker process and needs to be restored
on the current trainer.
"""
# NOTE: `add_to_queue` needs to be called before
callback_metrics: dict = queue.get()
# NOTE: `get_extra_results` needs to be called before
callback_metrics = extra["callback_metrics"]
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))

def kill(self, signum: _SIGNUM) -> None:
Expand All @@ -248,25 +250,12 @@ def __getstate__(self) -> Dict:
return state


class _FakeQueue(UserList):
"""Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""

def get(self) -> Any:
return self.pop(0)

def put(self, item: Any) -> None:
self.append(item)

def empty(self) -> bool:
return len(self) == 0


# Container for everything the spawned rank-0 worker process sends back to the
# main process after ``trainer`` finishes in the worker.
# NOTE(review): this span is GitHub-diff residue — it carries BOTH the old
# annotation of ``extra`` (``_FakeQueue``) and its replacement
# (``Dict[str, Any]``), and the original leading indentation was lost in
# extraction. Only one ``extra`` field exists in the real file; restore the
# intended version (and indentation) before using this text as source.
class _WorkerOutput(NamedTuple):
best_model_path: Optional[_PATH]
weights_path: Optional[_PATH]
trainer_state: TrainerState
trainer_results: Any
extra: _FakeQueue
extra: Dict[str, Any]


@dataclass
Expand Down
6 changes: 2 additions & 4 deletions src/lightning/pytorch/strategies/launchers/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from lightning.fabric.strategies.launchers.xla import _rank_teardown
from lightning.fabric.utilities import move_data_to_device
from lightning.pytorch.strategies.launchers.multiprocessing import (
_FakeQueue,
_GlobalStateSnapshot,
_MultiProcessingLauncher,
_WorkerOutput,
Expand Down Expand Up @@ -137,8 +136,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
if self._strategy.local_rank != 0:
return None

# adds the `callback_metrics` to the queue
extra = _FakeQueue()
self.add_to_queue(trainer, extra)
# add extra result data from trainer to send to main process
extra = self.get_extra_results(trainer)

return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
15 changes: 8 additions & 7 deletions tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ def test_ddp_cpu():


# Test helper that overrides the launcher's extra-results hooks to smuggle a
# sentinel value ("test_val") from the worker process back to the main process,
# where it is stored on ``trainer.strategy.test_val``.
# NOTE(review): this span is GitHub-diff residue — it interleaves the OLD
# overrides (``add_to_queue``/``get_from_queue``) with their NEW replacements
# (``get_extra_results``/``update_main_process_results``), and the original
# leading indentation was lost in extraction. In the real file only one pair of
# overrides exists; restore the intended pair (and indentation) before use.
class CustomMultiProcessingLauncher(_MultiProcessingLauncher):
def add_to_queue(self, trainer, queue) -> None:
queue.put("test_val")
return super().add_to_queue(trainer, queue)
def get_extra_results(self, trainer):
extra = super().get_extra_results(trainer)
extra["test_val"] = "test_val"
return extra

def get_from_queue(self, trainer: Trainer, queue) -> None:
trainer.strategy.test_val = queue.get()
return super().get_from_queue(trainer, queue)
def update_main_process_results(self, trainer, extra) -> None:
trainer.strategy.test_val = extra.pop("test_val")
return super().update_main_process_results(trainer, extra)


class TestDDPSpawnStrategy(DDPStrategy):
Expand All @@ -71,7 +72,7 @@ def _configure_launcher(self):

@RunIf(skip_windows=True)
def test_ddp_spawn_add_get_queue(tmpdir):
"""Tests add_to_queue/get_from_queue with DDPStrategy."""
"""Tests get_extra_results/update_main_process_results with DDPSpawnStrategy."""

ddp_spawn_strategy = TestDDPSpawnStrategy()
trainer = Trainer(
Expand Down