Merged
51 commits
d05363f  improve spawn queue (awaelchli, Oct 20, 2021)
d650e26  clean up (awaelchli, Oct 20, 2021)
5fda23a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 20, 2021)
d6b4a34  Merge branch 'master' into feature/simple-spawn (awaelchli, Nov 30, 2021)
bcfb853  fix (awaelchli, Nov 30, 2021)
97b4bf6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 30, 2021)
38b3a54  rename (awaelchli, Nov 30, 2021)
955b6c8  delete dead code (awaelchli, Nov 30, 2021)
13393e8  Merge remote-tracking branch 'origin/feature/simple-spawn' into featu… (awaelchli, Nov 30, 2021)
f3216b2  clean up (awaelchli, Nov 30, 2021)
2d00231  update lite (awaelchli, Nov 30, 2021)
7aa3646  retain the queue interface in hooks (awaelchli, Nov 30, 2021)
fb0c0d8  update tests (awaelchli, Nov 30, 2021)
1bc59ae  Merge branch 'master' into feature/simple-spawn (awaelchli, Nov 30, 2021)
7e6c75e  _notebooks (awaelchli, Nov 30, 2021)
b7efc50  reset notebooks (awaelchli, Nov 30, 2021)
84ca8b4  avoid circular import (awaelchli, Nov 30, 2021)
965c724  fix unused imports (awaelchli, Nov 30, 2021)
1aae8dd  reset debugging script (awaelchli, Nov 30, 2021)
4b998db  typing _ExtraQueue (awaelchli, Nov 30, 2021)
5871a4b  bring changes to tpu_spawn plugin (awaelchli, Nov 30, 2021)
aa76840  unify (awaelchli, Nov 30, 2021)
37f9db9  remove dead code (awaelchli, Nov 30, 2021)
d68cb35  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 30, 2021)
dd80be9  remove queue from tpu spawn (awaelchli, Nov 30, 2021)
f97eee8  type annotation for new_process (awaelchli, Nov 30, 2021)
ad61d27  Merge remote-tracking branch 'origin/feature/simple-spawn' into refac… (awaelchli, Nov 30, 2021)
459121e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 30, 2021)
72535ff  unused imports (awaelchli, Nov 30, 2021)
3095da9  Merge remote-tracking branch 'origin/feature/simple-spawn' into refac… (awaelchli, Nov 30, 2021)
61192df  move check (awaelchli, Nov 30, 2021)
801f529  revert (awaelchli, Nov 30, 2021)
1cd258b  collect results on tpu (awaelchli, Nov 30, 2021)
ae6019e  Merge branch 'master' into refactor/spawn/simple-spawn (awaelchli, Nov 30, 2021)
10ecbfd  rename (awaelchli, Nov 30, 2021)
ebba63f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 30, 2021)
d7df4d9  fix merge errors (awaelchli, Nov 30, 2021)
4c547aa  fix merge errors (awaelchli, Nov 30, 2021)
e4e2a77  re-add clean_logger (awaelchli, Dec 1, 2021)
86e43b2  Merge branch 'master' into refactor/spawn/simple-spawn (awaelchli, Dec 1, 2021)
acac29d  fix typing (awaelchli, Dec 1, 2021)
0ae457a  Merge branch 'master' into refactor/spawn/simple-spawn (awaelchli, Dec 1, 2021)
880c8fc  changelog entries (awaelchli, Dec 1, 2021)
5eeb02a  Merge branch 'master' into refactor/spawn/simple-spawn (awaelchli, Dec 1, 2021)
7520adc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 1, 2021)
96f2749  rename _ExtraQueue -> _FakeQueue (awaelchli, Dec 1, 2021)
65d183c  missing typing updates (awaelchli, Dec 1, 2021)
8c4e2e4  Introducing NamedTuple for spawn output typing (awaelchli, Dec 1, 2021)
213b447  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 1, 2021)
7521ec5  Merge branch 'master' into feature/simple-spawn (tchaton, Dec 2, 2021)
2579247  inherit from UserList (awaelchli, Dec 2, 2021)
12 changes: 11 additions & 1 deletion CHANGELOG.md
@@ -80,7 +80,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))


-
- The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))


- The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list-based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))


### Deprecated

@@ -188,6 +192,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))


- Removed the property `TrainingTypePlugin.results` and corresponding properties in subclasses ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))


- Removed the `mp_queue` attribute from `DDPSpawnPlugin` and `TPUSpawnPlugin` ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))


- Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))
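
To make the queue-related changelog entries above concrete, here is a minimal, hypothetical sketch of a user `LightningModule` overriding the two hooks under the new list-based queue (the `MyModel` class and the `extra_metric` key are invented for illustration; only the hook signatures follow the `pytorch_lightning/core/lightning.py` diff below):

```python
import torch
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    # Hypothetical module: only the two queue hooks are shown.

    def add_to_queue(self, queue) -> None:
        # `queue` is now the list-backed `_FakeQueue`, not a
        # `torch.multiprocessing.SimpleQueue`, but `put`/`get` work the same.
        super().add_to_queue(queue)  # puts trainer.callback_metrics first
        queue.put({"extra_metric": torch.tensor(0.5).cpu().numpy()})

    def get_from_queue(self, queue) -> None:
        super().get_from_queue(queue)  # pops callback_metrics first (FIFO)
        extra = queue.get()
        self.recovered_extra = {k: torch.tensor(v) for k, v in extra.items()}
```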


4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1917,7 +1917,7 @@ def model_size(self) -> float:
)
return get_model_size_mb(self)

def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.

@@ -1931,7 +1931,7 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
self.trainer.training_type_plugin.add_to_queue(self.trainer, queue)

def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.

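
From the user's perspective the hook change above is transparent; a hedged sketch of the resulting flow (the Trainer arguments are illustrative only, and `MyModel` is the hypothetical module from the earlier sketch, assumed to also define `training_step`, `configure_optimizers`, and its dataloaders):

```python
import pytorch_lightning as pl

# Illustrative configuration; any ddp_spawn-capable setup applies.
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2, max_epochs=1)
trainer.fit(MyModel())

# By the time fit() returns, the main process has already consumed the
# _SpawnOutput produced by rank zero (see the ddp_spawn.py diff below):
#   * checkpoint_callback.best_model_path has been restored,
#   * the last weights were reloaded because the state is FITTING,
#   * trainer.callback_metrics was repopulated via get_from_queue.
print(trainer.callback_metrics)
```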
124 changes: 65 additions & 59 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -14,8 +14,9 @@
import logging
import os
import re
from collections import UserList
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union

import numpy as np
import torch
@@ -45,7 +46,7 @@
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -80,7 +81,6 @@ def __init__(
self.sync_batchnorm = False
self._ddp_kwargs = kwargs
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self.mp_queue = None
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -101,15 +101,6 @@ def num_nodes(self, num_nodes: int) -> None:
def local_rank(self) -> int:
return self._local_rank

def __getstate__(self):
"""Makes this plugin pickleable without destroying the queue in the current process."""
state = self.__dict__.copy()
state["mp_queue"] = None
return state

def __setstate__(self, state):
self.__dict__ = state

@property
def root_device(self):
return self.parallel_devices[self.local_rank]
@@ -125,9 +116,6 @@ def _is_single_process_single_device(self):

def setup(self, trainer: "pl.Trainer") -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
# pass in a state q
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()
super().setup(trainer)

def _setup_model(self, model: Module) -> DistributedDataParallel:
@@ -145,18 +133,24 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
return {"nprocs": self.num_processes}

def start_training(self, trainer: "pl.Trainer") -> None:
self.spawn(self.new_process, trainer, self.mp_queue)
def start_training(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []
return spawn_output.trainer_results

def start_evaluating(self, trainer: "pl.Trainer") -> None:
self.spawn(self.new_process, trainer, self.mp_queue)
def start_evaluating(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results

def start_predicting(self, trainer: "pl.Trainer") -> None:
self.spawn(self.new_process, trainer, self.mp_queue)
def start_predicting(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results

def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
"""Spawn processes that run the given function.

Args:
@@ -191,9 +185,7 @@ def _worker_setup(self, process_idx: int):
self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
)

def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
self.mp_queue = mp_queue

def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
# move the model to the correct device
self.model_to_device()

@@ -208,28 +200,11 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
self.barrier()

results = trainer.run_stage()

# persist info in ddp_spawn
self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)
outputs = self.__collect_rank_zero_results(trainer, results)

# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()

def post_dispatch(self, trainer: "pl.Trainer"):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()
# get the `callback_metrics` and set it to the trainer
# only in case the user does not override it.
# TODO: Remove the if in v1.7
if is_overridden("get_from_queue", self.lightning_module):
self.lightning_module.get_from_queue(self.mp_queue)
else:
self.get_from_queue(trainer, self.mp_queue)

# recover the weights of the processes trained in the children
self.__recover_child_process_weights(best_path, last_path)
return outputs

def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
Expand Down Expand Up @@ -268,7 +243,7 @@ def determine_ddp_device_ids(self):
return None
return [self.root_device.index]

def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_warn("cleaning up ddp environment...")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -285,28 +260,37 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
self.checkpoint_io.save_checkpoint(state_dict, last_path)

# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
# adds the `callback_metrics` to the queue
# TODO: Remove the if in v1.7
extra = _FakeQueue()
if is_overridden("add_to_queue", self.lightning_module):
self.lightning_module.add_to_queue(self.mp_queue)
# TODO: Remove the if in v1.7
self.lightning_module.add_to_queue(extra)
else:
self.add_to_queue(trainer, self.mp_queue)
self.add_to_queue(trainer, extra)

def __recover_child_process_weights(self, best_path, last_path):
return _SpawnOutput(best_model_path, last_path, results, extra)

def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
# transfer back the best path to the trainer
if self.lightning_module.trainer.checkpoint_callback:
self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score
self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path

# TODO: pass also best score
# load last weights
if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage))
if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
ckpt = self.checkpoint_io.load_checkpoint(
spawn_output.last_path, map_location=(lambda storage, loc: storage)
)
self.lightning_module.load_state_dict(ckpt)

# get the `callback_metrics` and set it to the trainer
if is_overridden("get_from_queue", self.lightning_module):
# only in case the user does not override it.
# TODO: Remove the if in v1.7
self.lightning_module.get_from_queue(spawn_output.extra)
else:
self.get_from_queue(trainer, spawn_output.extra)

def barrier(self, *args, **kwargs) -> None:
if not distributed_available():
return
@@ -372,23 +356,25 @@ def post_training_step(self):
if not self.lightning_module.automatic_optimization:
self.model.require_backward_grad_sync = True

def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.

Args:
trainer: reference to the Trainer.
queue: the instance of the queue to append the data.
"""
callback_metrics: dict = apply_to_collection(
trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
) # send as numpy to avoid issues with memory sharing
queue.put(callback_metrics)

def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.

Args:
trainer: reference to the Trainer.
queue: the instance of the queue from where to get the data.
"""
# NOTE: `add_to_queue` needs to be called before
@@ -413,3 +399,23 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()


class _FakeQueue(UserList):
"""Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""

def get(self) -> Any:
return self.pop(0)

def put(self, item: Any) -> None:
self.append(item)

def empty(self) -> bool:
return len(self) == 0


class _SpawnOutput(NamedTuple):
best_model_path: Optional[_PATH]
last_path: Optional[_PATH]
trainer_results: Any
extra: _FakeQueue
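
Taken together, `_FakeQueue` and `_SpawnOutput` replace the old `mp_queue` handshake: the spawned rank-zero process packs everything into one picklable object and the main process unpacks it. A minimal sketch of that round trip, assuming both classes are importable from the `ddp_spawn` module exactly as defined above (paths and values are invented for illustration):

```python
from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, _SpawnOutput

# Rank-zero worker side: fill the extras queue and bundle the results.
extra = _FakeQueue()
extra.put({"val_loss": 0.123})  # what add_to_queue would contribute
spawn_output = _SpawnOutput(
    best_model_path="epoch=0-step=9.ckpt",      # illustrative paths
    last_path="epoch=0-step=9.tmp_end.ckpt",
    trainer_results=None,
    extra=extra,
)

# Main-process side: fields are read by name, extras drained in FIFO order.
assert spawn_output.extra.get() == {"val_loss": 0.123}
assert spawn_output.extra.empty()
print(spawn_output.best_model_path, spawn_output.last_path)
```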
9 changes: 4 additions & 5 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from multiprocessing.queues import SimpleQueue
from typing import Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple

import torch
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType
@@ -115,12 +114,12 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
def post_training_step(self):
pass

def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]:
# Ensure that the scaler points to the correct process group
# which is re-initialized in a new process
if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
self._precision_plugin.scaler = ShardedGradScaler()
return super().new_process(trainer, mp_queue)
return super().new_process(trainer)

@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:
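
A brief typing aside on the `new_process` annotation above: because `_SpawnOutput` is a `NamedTuple`, it is a plain tuple at runtime, so the `Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]` annotation stays compatible with the `_SpawnOutput` returned by the parent `DDPSpawnPlugin`. A standalone illustration with stand-in names (not PR code):

```python
from typing import Any, NamedTuple, Optional, Tuple


class _Output(NamedTuple):  # stand-in for _SpawnOutput
    best_model_path: Optional[str]
    last_path: Optional[str]
    trainer_results: Any
    extra: list


def consume(out: Optional[Tuple[Optional[str], Optional[str], Any, list]]) -> None:
    if out is None:
        return
    # NamedTuple instances index and unpack exactly like plain tuples.
    best, last, results, extra = out
    print(best, last, results, extra)


consume(_Output("best.ckpt", None, None, []))  # accepted: a NamedTuple is a tuple
```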