From d05363fca612e8cf8751dc5abac21c025bb5c9ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 04:53:23 +0200 Subject: [PATCH 01/41] improve spawn queue --- pl_examples/bug_report_model.py | 3 + pytorch_lightning/core/lightning.py | 4 +- .../plugins/training_type/ddp_spawn.py | 135 ++++++++++-------- .../plugins/training_type/sharded_spawn.py | 4 +- .../training_type/training_type_plugin.py | 16 +-- pytorch_lightning/trainer/trainer.py | 14 +- tests/plugins/test_ddp_spawn_plugin.py | 12 +- 7 files changed, 101 insertions(+), 87 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 270b0cd2abe8d..6a804c981033f 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -56,6 +56,9 @@ def run(): num_sanity_val_steps=0, max_epochs=1, enable_model_summary=False, + accelerator="cpu", + strategy="ddp_spawn", + devices=2, ) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) trainer.test(model, dataloaders=test_data) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ca4b2af7eee17..7989d2bf87c65 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1992,7 +1992,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: + def add_to_queue(self, queue: List[Any]) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -2006,7 +2006,7 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: + def get_from_queue(self, queue: List[Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index c72cc7f31d0cc..8f8a3b14cb01c 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import numpy as np import torch @@ -38,7 +38,7 @@ rank_zero_deprecation, rank_zero_warn, ) -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import distributed_available @@ -91,7 +91,6 @@ def __init__( self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 - self.mp_queue = None self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper @@ -120,11 +119,11 @@ def sync_batchnorm(self, sync_batchnorm: bool) -> None: def local_rank(self) -> int: return self._local_rank - def __getstate__(self): - """Makes this plugin pickleable without destroying the queue in the current process.""" - state = self.__dict__.copy() - state["mp_queue"] = None - return state + # def __getstate__(self): + # """Makes this plugin pickleable without destroying the queue in the current process.""" + # state = self.__dict__.copy() + # state["mp_queue"] = None # TODO: is this anymoe needed? + # return state def __setstate__(self, state): self.__dict__ = state @@ -144,9 +143,6 @@ def _is_single_process_single_device(self): def setup(self) -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - # pass in a state q - smp = mp.get_context("spawn") - self.mp_queue = smp.SimpleQueue() def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" @@ -163,18 +159,24 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} - def start_training(self, trainer: "pl.Trainer") -> None: - self.spawn(self.new_process, trainer, self.mp_queue) + def start_training(self, trainer: "pl.Trainer") -> Any: + best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) + self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] + return results def start_evaluating(self, trainer: "pl.Trainer") -> None: - self.spawn(self.new_process, trainer, self.mp_queue) + best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) + self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + return results def start_predicting(self, trainer: "pl.Trainer") -> None: - self.spawn(self.new_process, trainer, self.mp_queue) + best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) + self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + return results - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """Spawn processes that run the given function. Args: @@ -185,11 +187,18 @@ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None: These arguments must be pickleable. """ os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - mp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs()) + smp = mp.get_context("spawn") + mp_queue = smp.SimpleQueue() + mp.spawn(self._wrapped_function, args=(function, args, kwargs, mp_queue), nprocs=self.num_processes) + return mp_queue.get() - def _wrapped_function(self, process_idx: int, function: Callable, args: Any, kwargs: Any) -> None: + def _wrapped_function( + self, process_idx: int, function: Callable, args: Any, kwargs: Any, mp_queue: SimpleQueue + ) -> None: self._worker_setup(process_idx) - function(*args, **kwargs) + result = function(*args, **kwargs) + if self.is_global_zero: + mp_queue.put(move_data_to_device(result, "cpu")) def _worker_setup(self, process_idx: int): reset_seed() @@ -197,9 +206,7 @@ def _worker_setup(self, process_idx: int): rank_zero_only.rank = self.global_rank init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size) - def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: - self.mp_queue = mp_queue - + def new_process(self, trainer: "pl.Trainer") -> Any: # move the model to the correct device self.model_to_device() @@ -214,28 +221,16 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: self.barrier() results = trainer.run_stage() - - # persist info in ddp_spawn - self.__transfer_distrib_spawn_state_on_fit_end(trainer, results) + outputs = self.__transfer_distrib_spawn_state_on_fit_end(trainer, results) # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() + return outputs def post_dispatch(self, trainer: "pl.Trainer"): # restore main state with best weights - best_path = self.mp_queue.get() - last_path = self.mp_queue.get() - self._results = self.mp_queue.get() - # get the `callback_metrics` and set it to the trainer - # only in case the user does not override it. - # TODO: Remove the if in v1.7 - if is_overridden("get_from_queue", self.lightning_module): - self.lightning_module.get_from_queue(self.mp_queue) - else: - self.get_from_queue(trainer, self.mp_queue) - # recover the weights of the processes trained in the children - self.__recover_child_process_weights(best_path, last_path) + pass def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` @@ -276,34 +271,40 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None: + def __transfer_distrib_spawn_state_on_fit_end( + self, trainer: "pl.Trainer", results: Any + ) -> Optional[Tuple[Optional[str], Optional[str], Any, List[Any]]]: + checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if self.global_rank == 0 and self.mp_queue is not None: - rank_zero_warn("cleaning up ddp environment...") - - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - atomic_save(state_dict, last_path) - - # todo, pass complete checkpoint as state dictionary - self.mp_queue.put(best_model_path) - self.mp_queue.put(last_path) - self.mp_queue.put(results) - # adds the `callback_metrics` to the queue - # TODO: Remove the if in v1.7 - if is_overridden("add_to_queue", self.lightning_module): - self.lightning_module.add_to_queue(self.mp_queue) - else: - self.add_to_queue(trainer, self.mp_queue) - - def __recover_child_process_weights(self, best_path, last_path): + if not self.is_global_zero: + return + + rank_zero_warn("cleaning up ddp environment...") + + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + atomic_save(state_dict, last_path) + + extra = [] + # adds the `callback_metrics` to the queue + # TODO: Remove the if in v1.7 + if is_overridden("add_to_queue", self.lightning_module): + self.lightning_module.add_to_queue(extra) + else: + self.add_to_queue(trainer, extra) + + return best_model_path, last_path, results, extra + + def __recover_child_process_weights( + self, best_path: Optional[str], last_path: Optional[str], extra: List[Any], trainer + ) -> None: # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path @@ -314,6 +315,14 @@ def __recover_child_process_weights(self, best_path, last_path): ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) + # get the `callback_metrics` and set it to the trainer + # only in case the user does not override it. + # TODO: Remove the if in v1.7 + if is_overridden("get_from_queue", self.lightning_module): + self.lightning_module.get_from_queue(extra) + else: + self.get_from_queue(trainer, extra) + def barrier(self, *args, **kwargs) -> None: if not distributed_available(): return @@ -379,7 +388,7 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None: + def add_to_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -389,9 +398,9 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.Simpl callback_metrics: dict = apply_to_collection( trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() ) # send as numpy to avoid issues with memory sharing - queue.put(callback_metrics) + queue.append(callback_metrics) - def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None: + def get_from_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -399,7 +408,7 @@ def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.Sim queue: the instance of the queue from where to get the data. """ # NOTE: `add_to_queue` needs to be called before - callback_metrics: dict = queue.get() + callback_metrics: dict = queue.pop(0) trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))) @classmethod diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 78b54d029a5f6..9fd55a4e79bc0 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -101,13 +101,13 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: + def new_process(self, trainer: "pl.Trainer") -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process precision_plugin = trainer.accelerator.precision_plugin if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin): precision_plugin.scaler = ShardedGradScaler() - return super().new_process(trainer, mp_queue) + return super().new_process(trainer) @classmethod def register_plugins(cls, plugin_registry: Dict) -> None: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 481b9ee1c4087..39227b0008d84 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -35,7 +35,6 @@ class TrainingTypePlugin(ABC): def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: self._model: Optional[Module] = None - self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io @@ -188,7 +187,8 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: start-method send the result to the master process through a `multiprocessing queue (shared memory) `_. """ - return self._results + # TODO: deprecate this + return None def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() @@ -202,17 +202,17 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) - def start_training(self, trainer: "pl.Trainer") -> None: + def start_training(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the training loop - self._results = trainer.run_stage() + return trainer.run_stage() - def start_evaluating(self, trainer: "pl.Trainer") -> None: + def start_evaluating(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the test loop - self._results = trainer.run_stage() + return trainer.run_stage() - def start_predicting(self, trainer: "pl.Trainer") -> None: + def start_predicting(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the predicting loop - self._results = trainer.run_stage() + return trainer.run_stage() def training_step(self, *args, **kwargs): return self.model.training_step(*args, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f9c18d0a8462f..4548ccc8702bc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1097,9 +1097,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.checkpoint_connector.restore_training_state() # dispatch `start_training` or `start_evaluating` or `start_predicting` - self._dispatch() + results = self._dispatch() - # plugin will finalized fitting (e.g. ddp_spawn will load trained model) + # TODO: needed? self._post_dispatch() # ---------------------------- @@ -1118,7 +1118,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.state.status = TrainerStatus.FINISHED self.state.stage = None - return self.training_type_plugin.results + return results def _pre_dispatch(self): self.accelerator.pre_dispatch(self) @@ -1170,13 +1170,13 @@ def _post_dispatch(self): self._active_loop.teardown() self.logger_connector.teardown() - def _dispatch(self): + def _dispatch(self) -> Any: if self.evaluating: - self.training_type_plugin.start_evaluating(self) + return self.training_type_plugin.start_evaluating(self) elif self.predicting: - self.training_type_plugin.start_predicting(self) + return self.training_type_plugin.start_predicting(self) else: - self.training_type_plugin.start_training(self) + return self.training_type_plugin.start_training(self) def run_stage(self): self.accelerator.dispatch(self) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index c389cf9290c78..804cf3d9e2ee0 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Any + import torch from torch.nn.parallel.distributed import DistributedDataParallel @@ -38,11 +40,11 @@ def validation_step(self, batch, batch_idx): return super().validation_step(batch, batch_idx) def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - queue.put("test_val") + queue.append("test_val") return super().add_to_queue(queue) def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - self.test_val = queue.get() + self.test_val = queue.pop(0) return super().get_from_queue(queue) @@ -83,11 +85,11 @@ def test_ddp_spawn_extra_parameters(tmpdir): class TestDDPSpawnPlugin(DDPSpawnPlugin): def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None: - queue.put("new_test_val") + queue.append("new_test_val") return super().add_to_queue(trainer, queue) - def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None: - self.new_test_val = queue.get() + def get_from_queue(self, trainer: Trainer, queue: List[Any]) -> None: + self.new_test_val = queue.pop(0) return super().get_from_queue(trainer, queue) From d650e2641e7649363b2ec7a8192b60f686fff2da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 05:04:10 +0200 Subject: [PATCH 02/41] clean up --- .../plugins/training_type/ddp_spawn.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 8f8a3b14cb01c..517ea4042767d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -119,10 +119,11 @@ def sync_batchnorm(self, sync_batchnorm: bool) -> None: def local_rank(self) -> int: return self._local_rank + # TODO: this should no longer be needed # def __getstate__(self): # """Makes this plugin pickleable without destroying the queue in the current process.""" # state = self.__dict__.copy() - # state["mp_queue"] = None # TODO: is this anymoe needed? + # state["mp_queue"] = None # return state def __setstate__(self, state): @@ -142,6 +143,7 @@ def _is_single_process_single_device(self): return True def setup(self) -> None: + # TODO: is this needed here? already getting set in spawn() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) def _setup_model(self, model: Module) -> DistributedDataParallel: @@ -227,11 +229,6 @@ def new_process(self, trainer: "pl.Trainer") -> Any: trainer._call_teardown_hook() return outputs - def post_dispatch(self, trainer: "pl.Trainer"): - # restore main state with best weights - - pass - def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` # Many models require setting this parameter to True, as there are corner cases @@ -292,10 +289,10 @@ def __transfer_distrib_spawn_state_on_fit_end( last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) atomic_save(state_dict, last_path) - extra = [] # adds the `callback_metrics` to the queue - # TODO: Remove the if in v1.7 + extra = [] if is_overridden("add_to_queue", self.lightning_module): + # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) else: self.add_to_queue(trainer, extra) @@ -316,9 +313,9 @@ def __recover_child_process_weights( self.lightning_module.load_state_dict(ckpt) # get the `callback_metrics` and set it to the trainer - # only in case the user does not override it. - # TODO: Remove the if in v1.7 if is_overridden("get_from_queue", self.lightning_module): + # only in case the user does not override it. + # TODO: Remove the if in v1.7 self.lightning_module.get_from_queue(extra) else: self.get_from_queue(trainer, extra) @@ -393,6 +390,7 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: sharing, we cast the data to numpy. Args: + trainer: reference to the Trainer. queue: the instance of the queue to append the data. """ callback_metrics: dict = apply_to_collection( @@ -405,6 +403,7 @@ def get_from_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: we cast back the data to ``torch.Tensor``. Args: + trainer: reference to the Trainer. queue: the instance of the queue from where to get the data. """ # NOTE: `add_to_queue` needs to be called before From 5fda23a99a7d3a5501e064735bdda18fe9fe8a6c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Oct 2021 03:10:00 +0000 Subject: [PATCH 03/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- tests/plugins/test_ddp_spawn_plugin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 517ea4042767d..8802a5d016433 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 804cf3d9e2ee0..3b293ea315f59 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Any +from typing import Any, List import torch from torch.nn.parallel.distributed import DistributedDataParallel From bcfb853fa6ad7ed5318c6e37eda569532cd44d2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:04:55 +0100 Subject: [PATCH 04/41] fix --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 9f5b193dcefd6..4862518a872e2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -125,7 +125,7 @@ def distributed_sampler_kwargs(self): @property def _is_single_process_single_device(self): return True - + def setup(self, trainer: "pl.Trainer") -> None: # TODO: is this needed here? already getting set in spawn() os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) @@ -173,8 +173,6 @@ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: function: The function to spawn processes from. *args: Optional positional arguments that will be passed to the function in addition to the process index. These arguments must be pickleable. - return_result: If ``True``, copies the output of the function from process 0 to the main process and - returns it. **kwargs: Optional named arguments that will be passed to the function in addition to the process index. These arguments must be pickleable. @@ -183,9 +181,9 @@ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """ os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) context = mp.get_context("spawn") - return_queue = context.SimpleQueue() if return_result else None + return_queue = context.SimpleQueue() mp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), nprocs=self.num_processes) - return mp_queue.get() + return return_queue.get() def _wrapped_function( self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue From 97b4bf6046f2bbfe55431ad0d4110645103c3093 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 14:09:24 +0000 Subject: [PATCH 05/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/plugins/test_ddp_spawn_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index f85545e83ea83..4f165fe963a46 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, List + import pytest import torch from torch.nn.parallel.distributed import DistributedDataParallel From 38b3a548731bbd9066ae994b26e0ac0434b19e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:14:18 +0100 Subject: [PATCH 06/41] rename --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 4862518a872e2..5214f9a12ce45 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -216,7 +216,7 @@ def new_process(self, trainer: "pl.Trainer") -> Any: self.barrier() results = trainer.run_stage() - outputs = self.__transfer_distrib_spawn_state_on_fit_end(trainer, results) + outputs = self.__collect_rank_zero_results(trainer, results) # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() @@ -259,7 +259,7 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __transfer_distrib_spawn_state_on_fit_end( + def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any ) -> Optional[Tuple[Optional[str], Optional[str], Any, List[Any]]]: diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 5ef8a46d7127f..7d36dbfef8d33 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -188,7 +188,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: results = trainer.run_stage() - self.__transfer_distrib_spawn_state_on_fit_end(trainer, results) + self.__collect_rank_zero_results(trainer, results) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self.barrier("end-process") @@ -207,7 +207,7 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None: + def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> None: checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None From 955b6c8f6c66bc6a2578ef919187a98bad1fea61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:14:24 +0100 Subject: [PATCH 07/41] delete dead code --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5214f9a12ce45..efe8df41d570e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -103,16 +103,6 @@ def num_nodes(self, num_nodes: int) -> None: def local_rank(self) -> int: return self._local_rank - # TODO: this should no longer be needed - # def __getstate__(self): - # """Makes this plugin pickleable without destroying the queue in the current process.""" - # state = self.__dict__.copy() - # state["mp_queue"] = None - # return state - - def __setstate__(self, state): - self.__dict__ = state - @property def root_device(self): return self.parallel_devices[self.local_rank] From f3216b21b49d04944912f546df1230f64043f9ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:24:41 +0100 Subject: [PATCH 08/41] clean up --- .../plugins/training_type/training_type_plugin.py | 13 ------------- pytorch_lightning/trainer/trainer.py | 1 - 2 files changed, 14 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 9923500574bfb..0dcc31c4482d6 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -287,19 +287,6 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return unwrap_lightning_module(self._model) if self._model is not None else None - @property - def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: - """Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run. - - The result is - cached instead of returned directly, because some plugins require transmitting the results from one - multiprocessing context to another in a separate step. For example, the plugins that use the "spawn" - start-method send the result to the main process through a - `multiprocessing queue (shared memory) `_. - """ - # TODO: deprecate this - return None - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 87d71db7576d4..f2ec7ef4b7332 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1162,7 +1162,6 @@ def _run( # dispatch `start_training` or `start_evaluating` or `start_predicting` results = self._dispatch() - # TODO: needed? self._post_dispatch() # ---------------------------- From 2d00231acb52eb8a6f55890ff90ce5303995d80f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:26:11 +0100 Subject: [PATCH 09/41] update lite --- pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index fede7f5df7291..3a6a814ce9200 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -403,7 +403,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: run_method = partial(self._run_with_sharded_context, run_method) if isinstance(self._strategy, DDPSpawnPlugin): - return self._strategy.spawn(run_method, *args, return_result=True, **kwargs) + return self._strategy.spawn(run_method, *args, **kwargs) else: return run_method(*args, **kwargs) From 7aa36461c05632f016ddcf07f9617ba49dad9f06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:40:28 +0100 Subject: [PATCH 10/41] retain the queue interface in hooks --- pytorch_lightning/core/lightning.py | 5 +++-- .../plugins/training_type/ddp_spawn.py | 21 +++++++++++++++---- tests/plugins/test_ddp_spawn_plugin.py | 6 +++--- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 814caa3e6cf6d..3f8a547149f1a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -36,6 +36,7 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO +from pytorch_lightning.plugins.training_type.ddp_spawn import _SimpleQueue from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator from pytorch_lightning.utilities import ( _IS_WINDOWS, @@ -1928,7 +1929,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: List[Any]) -> None: + def add_to_queue(self, queue: _SimpleQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1942,7 +1943,7 @@ def add_to_queue(self, queue: List[Any]) -> None: if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: List[Any]) -> None: + def get_from_queue(self, queue: _SimpleQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index efe8df41d570e..cba316e2c514e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -366,7 +366,7 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: + def add_to_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -377,9 +377,9 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: callback_metrics: dict = apply_to_collection( trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() ) # send as numpy to avoid issues with memory sharing - queue.append(callback_metrics) + queue.put(callback_metrics) - def get_from_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: + def get_from_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -388,7 +388,7 @@ def get_from_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: queue: the instance of the queue from where to get the data. """ # NOTE: `add_to_queue` needs to be called before - callback_metrics: dict = queue.pop(0) + callback_metrics: dict = queue.get() trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))) @classmethod @@ -422,3 +422,16 @@ def _clean_logger(trainer: "pl.Trainer") -> None: # we want to make sure these are closed before we spawn our own threads. # assuming nothing else references the experiment object, python should instantly `__del__` it. logger._experiment = None + + +class _SimpleQueue(list): + """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` using the Python list interface.""" + + def get(self) -> Any: + return self.pop(0) + + def put(self, item: Any) -> None: + self.append(item) + + def empty(self) -> bool: + return len(self) == 0 diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 4f165fe963a46..21736837d4ed1 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -86,12 +86,12 @@ def test_ddp_spawn_extra_parameters(tmpdir): class TestDDPSpawnPlugin(DDPSpawnPlugin): - def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None: - queue.append("new_test_val") + def add_to_queue(self, trainer, queue) -> None: + queue.put("new_test_val") return super().add_to_queue(trainer, queue) def get_from_queue(self, trainer: Trainer, queue: List[Any]) -> None: - self.new_test_val = queue.pop(0) + self.new_test_val = queue.get() return super().get_from_queue(trainer, queue) From fb0c0d8bb4f35e8a59f01527d4c18b052f8b40be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:42:13 +0100 Subject: [PATCH 11/41] update tests --- tests/plugins/test_ddp_spawn_plugin.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 21736837d4ed1..6ea265a4bb575 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List import pytest import torch @@ -40,12 +39,12 @@ def validation_step(self, batch, batch_idx): self.log(self.name, self.val) return super().validation_step(batch, batch_idx) - def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - queue.append("test_val") + def add_to_queue(self, queue) -> None: + queue.put("test_val") return super().add_to_queue(queue) - def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - self.test_val = queue.pop(0) + def get_from_queue(self, queue) -> None: + self.test_val = queue.get() return super().get_from_queue(queue) @@ -90,7 +89,7 @@ def add_to_queue(self, trainer, queue) -> None: queue.put("new_test_val") return super().add_to_queue(trainer, queue) - def get_from_queue(self, trainer: Trainer, queue: List[Any]) -> None: + def get_from_queue(self, trainer: Trainer, queue) -> None: self.new_test_val = queue.get() return super().get_from_queue(trainer, queue) From 7e6c75ea5ee5ac4bebd909e033ac719c63e50c6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:43:47 +0100 Subject: [PATCH 12/41] _notebooks --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index a2fb6468112b7..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc From b7efc5052b156aec31156d53d06413672b63367e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:44:06 +0100 Subject: [PATCH 13/41] reset notebooks --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..0c325829101d5 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 0c325829101d5a6ebf32ed99bbf5b09badf04a59 From 84ca8b4ad50230da1fdbd314853cc30ff20c274d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:58:14 +0100 Subject: [PATCH 14/41] avoid circular import --- pytorch_lightning/core/lightning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3f8a547149f1a..736e8f5a9b560 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -36,7 +36,6 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO -from pytorch_lightning.plugins.training_type.ddp_spawn import _SimpleQueue from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator from pytorch_lightning.utilities import ( _IS_WINDOWS, @@ -1929,7 +1928,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: _SimpleQueue) -> None: + def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1943,7 +1942,7 @@ def add_to_queue(self, queue: _SimpleQueue) -> None: if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: _SimpleQueue) -> None: + def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. From 965c7245aa7785012ec671b5e3963409ab846235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 16:04:29 +0100 Subject: [PATCH 15/41] fix unused imports --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 1 - pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index e79c0aa9a1c24..c9a968fa94fbd 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from multiprocessing.queues import SimpleQueue from typing import Dict, Generator, List, Optional, Tuple import torch diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index f97bec44bc0ea..ef5f78a1e09f3 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -30,7 +30,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import ReduceOp -from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT +from pytorch_lightning.utilities.types import _PATH TBroadcast = TypeVar("TBroadcast") From 1aae8ddc78848fda16215b92cc1e83e00966d401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 16:08:11 +0100 Subject: [PATCH 16/41] reset debugging script --- pl_examples/bug_report/bug_report_model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 5e43eeca17308..7739630237d32 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -57,9 +57,6 @@ def run(): num_sanity_val_steps=0, max_epochs=1, enable_model_summary=False, - accelerator="cpu", - strategy="ddp_spawn", - devices=2, ) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) trainer.test(model, dataloaders=test_data) From 4b998db77a32c1fc5b100ed0a618fe305cb843ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 16:20:43 +0100 Subject: [PATCH 17/41] typing _ExtraQueue --- pytorch_lightning/core/lightning.py | 4 ++-- .../plugins/training_type/ddp_spawn.py | 17 ++++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 736e8f5a9b560..68611b93079c0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1928,7 +1928,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) -> None: + def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1942,7 +1942,7 @@ def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) - if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) -> None: + def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index cba316e2c514e..59849bc08c389 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -117,7 +117,6 @@ def _is_single_process_single_device(self): return True def setup(self, trainer: "pl.Trainer") -> None: - # TODO: is this needed here? already getting set in spawn() os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) super().setup(trainer) @@ -191,7 +190,7 @@ def _worker_setup(self, process_idx: int): self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size ) - def new_process(self, trainer: "pl.Trainer") -> Any: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: # move the model to the correct device self.model_to_device() @@ -251,7 +250,7 @@ def determine_ddp_device_ids(self): def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, List[Any]]]: + ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -271,7 +270,7 @@ def __collect_rank_zero_results( atomic_save(state_dict, last_path) # adds the `callback_metrics` to the queue - extra = [] + extra = _ExtraQueue() if is_overridden("add_to_queue", self.lightning_module): # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) @@ -281,7 +280,7 @@ def __collect_rank_zero_results( return best_model_path, last_path, results, extra def __recover_child_process_weights( - self, best_path: Optional[str], last_path: Optional[str], extra: List[Any], trainer + self, best_path: Optional[str], last_path: Optional[str], extra: "_ExtraQueue", trainer ) -> None: # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: @@ -366,7 +365,7 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: + def add_to_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -379,7 +378,7 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: ) # send as numpy to avoid issues with memory sharing queue.put(callback_metrics) - def get_from_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: + def get_from_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -424,8 +423,8 @@ def _clean_logger(trainer: "pl.Trainer") -> None: logger._experiment = None -class _SimpleQueue(list): - """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` using the Python list interface.""" +class _ExtraQueue(list): + """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" def get(self) -> Any: return self.pop(0) From 5871a4bacc64a868744618764ef341bb2a7ac6f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 16:55:38 +0100 Subject: [PATCH 18/41] bring changes to tpu_spawn plugin --- .../plugins/training_type/ddp_spawn.py | 2 +- .../plugins/training_type/tpu_spawn.py | 51 +++++++++++-------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 59849bc08c389..4eac619a08888 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -258,7 +258,7 @@ def __collect_rank_zero_results( # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if not self.is_global_zero: + if self.local_rank != 0: return rank_zero_warn("cleaning up ddp environment...") diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7d36dbfef8d33..d3f03eb122ebc 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -16,7 +16,7 @@ import re import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import torch import torch.multiprocessing as mp @@ -28,7 +28,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _ExtraQueue from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters @@ -207,28 +207,37 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> None: + # TODO: this implementation is identical to the one in the super class, up to the self.save() call + def __collect_rank_zero_results( + self, trainer: "pl.Trainer", results: Any + ) -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: + checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if self.mp_queue is not None: - rank_zero_warn("cleaning up tpu spawn environment...") + if self.local_rank != 0: + return + + rank_zero_warn("cleaning up ddp environment...") - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.save(state_dict, last_path) + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.save(state_dict, last_path) - if self.local_rank == 0: - # todo, pass complete checkpoint as state dictionary - self.mp_queue.put(best_model_path) - self.mp_queue.put(last_path) - self.mp_queue.put(results) - self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue + # adds the `callback_metrics` to the queue + extra = _ExtraQueue() + if is_overridden("add_to_queue", self.lightning_module): + # TODO: Remove the if in v1.7 + self.lightning_module.add_to_queue(extra) + else: + self.add_to_queue(trainer, extra) + + return best_model_path, last_path, results, extra def save(self, state_dict: Dict, path: _PATH) -> None: xm.save(state_dict, path) @@ -275,18 +284,18 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]: context = mp.get_context(self.start_method or "fork") - return_queue = context.SimpleQueue() if return_result else None + return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) - return return_queue.get() if return_result else None + return return_queue.get() def _wrapped_function( - self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: Optional[SimpleQueue] + self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue ) -> None: self._worker_setup(process_idx) result = function(*args, **kwargs) - if return_queue is not None and self.local_rank == 0: + if self.local_rank == 0: return_queue.put(move_data_to_device(result, "cpu")) self.barrier("end-process") From aa76840fdffad9047ee5ff9a4474df87d9ffb26b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:18:05 +0100 Subject: [PATCH 19/41] unify --- .../plugins/training_type/ddp_spawn.py | 2 +- .../plugins/training_type/tpu_spawn.py | 32 ------------------- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 4eac619a08888..8ab51f3fb0ee2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -267,7 +267,7 @@ def __collect_rank_zero_results( last_path = None if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - atomic_save(state_dict, last_path) + self.save_checkpoint(state_dict, last_path) # adds the `callback_metrics` to the queue extra = _ExtraQueue() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index d3f03eb122ebc..da1ed20dd405c 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -207,38 +207,6 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - # TODO: this implementation is identical to the one in the super class, up to the self.save() call - def __collect_rank_zero_results( - self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: - - checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None - - # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self.lightning_module.state_dict() - - if self.local_rank != 0: - return - - rank_zero_warn("cleaning up ddp environment...") - - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.save(state_dict, last_path) - - # adds the `callback_metrics` to the queue - extra = _ExtraQueue() - if is_overridden("add_to_queue", self.lightning_module): - # TODO: Remove the if in v1.7 - self.lightning_module.add_to_queue(extra) - else: - self.add_to_queue(trainer, extra) - - return best_model_path, last_path, results, extra - def save(self, state_dict: Dict, path: _PATH) -> None: xm.save(state_dict, path) From 37f9db9f9082313c388d0adb45804b112f697917 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:18:44 +0100 Subject: [PATCH 20/41] remove dead code --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index da1ed20dd405c..ec773f1fc70ac 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -207,9 +207,6 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def save(self, state_dict: Dict, path: _PATH) -> None: - xm.save(state_dict, path) - def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: return obj From d68cb35abd187c5abf80b29b3b449b34608a03e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 17:20:17 +0000 Subject: [PATCH 21/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index ec773f1fc70ac..1fec30cc5fb9a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -16,7 +16,7 @@ import re import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.multiprocessing as mp @@ -28,7 +28,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _ExtraQueue +from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters From dd80be94e74fb190c1442b41633bd2a1aa98ea7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:25:05 +0100 Subject: [PATCH 22/41] remove queue from tpu spawn --- .../plugins/training_type/tpu_spawn.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index ec773f1fc70ac..01f3c8dac9134 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -122,7 +122,7 @@ def pre_dispatch(self): os.environ["PT_XLA_DEBUG"] = str(1) def setup(self, trainer: "pl.Trainer") -> None: - self.create_mp_queue() + self.start_method = "fork" if not self.setup_optimizers_in_pre_dispatch: self.setup_optimizers(trainer) self.setup_precision_plugin() @@ -138,11 +138,6 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: def _setup_model(self, model: Module) -> Module: return model - def create_mp_queue(self): - self.start_method = "fork" - smp = mp.get_context(self.start_method) - self.mp_queue = smp.SimpleQueue() - @property def distributed_sampler_kwargs(self) -> Dict[str, int]: return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) @@ -168,9 +163,7 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None: def set_world_ranks(self, process_idx: int = 0) -> None: pass - def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: - self.mp_queue = mp_queue - + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: trainer.progress_bar_callback.disable() @@ -188,7 +181,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: results = trainer.run_stage() - self.__collect_rank_zero_results(trainer, results) + outputs = self.__collect_rank_zero_results(trainer, results) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self.barrier("end-process") @@ -199,6 +192,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() + return outputs def model_to_device(self) -> None: self.model = self.wrapped_model.to(self.root_device) From f97eee894fa63c217019eb0b7687a6dc99656fc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:25:22 +0100 Subject: [PATCH 23/41] type annotation for new_process --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index c9a968fa94fbd..b6dff138803ed 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Dict, Generator, List, Optional, Tuple +from typing import Dict, Generator, List, Optional, Tuple, Any import torch from torch.nn import Module @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _ExtraQueue from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType @@ -114,7 +114,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer") -> None: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): From 459121ebbc5a74d5cb39f7acefb2b66e94f6e97d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 17:27:05 +0000 Subject: [PATCH 24/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index b6dff138803ed..55b9253d3101c 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Dict, Generator, List, Optional, Tuple, Any +from typing import Any, Dict, Generator, List, Optional, Tuple import torch from torch.nn import Module @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _ExtraQueue +from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType From 72535ff34d5de83c1bf87a02a61d383ea338ea28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:30:13 +0100 Subject: [PATCH 25/41] unused imports --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 - pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 8ab51f3fb0ee2..b9aa02bed397b 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -35,7 +35,6 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device -from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.distributed import group as _group diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 0ed9563177274..e81cb5bfbe16b 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -13,7 +13,6 @@ # limitations under the License. import io import os -import re import time from multiprocessing.queues import SimpleQueue from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -30,8 +29,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters +from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp From 61192df022c6c95ec6b65c872068c41299a0f2a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:59:30 +0100 Subject: [PATCH 26/41] move check --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index b9aa02bed397b..370e9ceda1525 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -257,9 +257,6 @@ def __collect_rank_zero_results( # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if self.local_rank != 0: - return - rank_zero_warn("cleaning up ddp environment...") # save the last weights @@ -268,6 +265,9 @@ def __collect_rank_zero_results( last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) self.save_checkpoint(state_dict, last_path) + if self.local_rank != 0: + return + # adds the `callback_metrics` to the queue extra = _ExtraQueue() if is_overridden("add_to_queue", self.lightning_module): From 801f529dbc8698d5d3b2a197c364f49fbeae4274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:05:30 +0100 Subject: [PATCH 27/41] revert --- .../plugins/training_type/tpu_spawn.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index e81cb5bfbe16b..58585b9b79a52 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -199,6 +199,29 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) + def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None: + checkpoint_callback = trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = self.lightning_module.state_dict() + + if self.mp_queue is not None: + rank_zero_warn("cleaning up tpu spawn environment...") + + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.checkpoint_io.save_checkpoint(state_dict, last_path) + + if self.local_rank == 0: + # todo, pass complete checkpoint as state dictionary + self.mp_queue.put(best_model_path) + self.mp_queue.put(last_path) + self.mp_queue.put(results) + self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue + def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: return obj From 1cd258b70ad3c9c0869b74e71093f0ff4a24cbbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:13:37 +0100 Subject: [PATCH 28/41] collect results on tpu --- .../plugins/training_type/tpu_spawn.py | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 58585b9b79a52..64926f00af21a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -13,6 +13,7 @@ # limitations under the License. import io import os +import re import time from multiprocessing.queues import SimpleQueue from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -29,7 +30,8 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp @@ -199,28 +201,35 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None: + def __collect_rank_zero_results( + self, trainer: "pl.Trainer", results: Any + ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if self.mp_queue is not None: - rank_zero_warn("cleaning up tpu spawn environment...") - - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) - - if self.local_rank == 0: - # todo, pass complete checkpoint as state dictionary - self.mp_queue.put(best_model_path) - self.mp_queue.put(last_path) - self.mp_queue.put(results) - self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue + rank_zero_warn("cleaning up tpu spawn environment...") + + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.checkpoint_io.save_checkpoint(state_dict, last_path) + + if self.local_rank != 0: + return + + # adds the `callback_metrics` to the queue + extra = _ExtraQueue() + if is_overridden("add_to_queue", self.lightning_module): + # TODO: Remove the if in v1.7 + self.lightning_module.add_to_queue(extra) + else: + self.add_to_queue(trainer, extra) + + return best_model_path, last_path, results, extra def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: From 10ecbfd0da11d08fc6e12932b9153c2ae1352396 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:37:11 +0100 Subject: [PATCH 29/41] rename --- .../plugins/training_type/ddp_spawn.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 3781e40b8f343..e52d403de0ffb 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -134,19 +134,19 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st def start_training(self, trainer: "pl.Trainer") -> Any: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] return results def start_evaluating(self, trainer: "pl.Trainer") -> None: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results def start_predicting(self, trainer: "pl.Trainer") -> None: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: @@ -254,11 +254,11 @@ def __collect_rank_zero_results( rank_zero_warn("cleaning up ddp environment...") - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.checkpoint_io.save_checkpoint(state_dict, last_path) if self.local_rank != 0: return @@ -273,7 +273,7 @@ def __collect_rank_zero_results( return best_model_path, last_path, results, extra - def __recover_child_process_weights( + def __recover_results_in_main_process( self, best_path: Optional[str], last_path: Optional[str], extra: "_ExtraQueue", trainer ) -> None: # transfer back the best path to the trainer From ebba63f4be7ad1a8f041c735938207adc6968ab9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 22:38:29 +0000 Subject: [PATCH 30/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index bbfd894db82ca..cf46f2224b437 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -32,7 +32,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters, rank_zero_warn +from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp From d7df4d93be44783a3695b6218d0b931bd39facf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:48:38 +0100 Subject: [PATCH 31/41] fix merge errors --- .../plugins/training_type/ddp_spawn.py | 6 +++--- .../plugins/training_type/tpu_spawn.py | 21 ++++++------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index e52d403de0ffb..393a93ea69a61 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -252,6 +252,9 @@ def __collect_rank_zero_results( # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() + if self.local_rank != 0: + return + rank_zero_warn("cleaning up ddp environment...") # save the last weights @@ -260,9 +263,6 @@ def __collect_rank_zero_results( last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) self.checkpoint_io.save_checkpoint(state_dict, last_path) - if self.local_rank != 0: - return - # adds the `callback_metrics` to the queue extra = _ExtraQueue() if is_overridden("add_to_queue", self.lightning_module): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index cf46f2224b437..abe297cdb3b58 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -204,7 +204,7 @@ def barrier(self, name: Optional[str] = None) -> None: def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: + ) -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -213,11 +213,11 @@ def __collect_rank_zero_results( rank_zero_warn("cleaning up tpu spawn environment...") - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.checkpoint_io.save_checkpoint(state_dict, last_path) if self.local_rank != 0: return @@ -303,17 +303,8 @@ def start_training(self, trainer: "pl.Trainer") -> None: # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] - self._clean_logger(trainer) return super().start_training(trainer) - def start_evaluating(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) - return super().start_evaluating(trainer) - - def start_predicting(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) - return super().start_predicting(trainer) - def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) From 4c547aa95b9a1cd2669a7e66da9749aecb9ed255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:49:07 +0100 Subject: [PATCH 32/41] fix merge errors --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index abe297cdb3b58..2dfe9709cc1cf 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -380,13 +380,3 @@ def checkpoint_io(self) -> CheckpointIO: @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") - - @staticmethod - def _clean_logger(trainer: "pl.Trainer") -> None: - loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] - for logger in loggers: - if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: - # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. - # we want to make sure these are closed before we spawn our own threads. - # assuming nothing else references the experiment object, python should instantly `__del__` it. - logger._experiment = None From e4e2a771f195fb445e1916b2b9df02579fdd707b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 04:01:07 +0100 Subject: [PATCH 33/41] re-add clean_logger --- .../plugins/training_type/tpu_spawn.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 2dfe9709cc1cf..ac69b501993ef 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -303,8 +303,17 @@ def start_training(self, trainer: "pl.Trainer") -> None: # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] + self._clean_logger(trainer) return super().start_training(trainer) + def start_evaluating(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) + return super().start_evaluating(trainer) + + def start_predicting(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) + return super().start_predicting(trainer) + def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) @@ -380,3 +389,13 @@ def checkpoint_io(self) -> CheckpointIO: @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") + + @staticmethod + def _clean_logger(trainer: "pl.Trainer") -> None: + loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] + for logger in loggers: + if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: + # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. + # we want to make sure these are closed before we spawn our own threads. + # assuming nothing else references the experiment object, python should instantly `__del__` it. + logger._experiment = None From acac29db559dc8a29d5295c2c93be7d1f54e7d14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 15:28:38 +0100 Subject: [PATCH 34/41] fix typing --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 393a93ea69a61..9a7439c42990d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -149,7 +149,7 @@ def start_predicting(self, trainer: "pl.Trainer") -> None: self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]: """Spawn processes that run the given function. Args: From 880c8fc8db0a013cb342004fea96ab7cda42e821 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 17:28:09 +0100 Subject: [PATCH 35/41] changelog entries --- CHANGELOG.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 629b28e392792..0e1a7993aa687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,7 +80,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649)) -- +- The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) + + +- The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) + ### Deprecated @@ -188,6 +192,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867)) +- Removed the property `TrainingTypePlugin.results` and corresponding properties in subclasses ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) + + +- Removed the `mp_queue` attribute from `DDPSpawnPlugin` and `TPUSpawnPlugin` ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) + + ### Fixed - Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611)) From 7520adcce70d9abdbe23a7a28059a9f68a1e49e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Dec 2021 17:34:21 +0000 Subject: [PATCH 36/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 9396c19060833..1f73595c37418 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -217,7 +217,7 @@ def __collect_rank_zero_results( if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) self.checkpoint_io.save_checkpoint(state_dict, last_path) - + # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training if self.local_rank != 0: return From 96f2749cea3d9dbf8dd0fe0b6509642fd01f0b24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 21:11:53 +0100 Subject: [PATCH 37/41] rename _ExtraQueue -> _FakeQueue --- pytorch_lightning/core/lightning.py | 4 ++-- .../plugins/training_type/ddp_spawn.py | 14 +++++++------- .../plugins/training_type/sharded_spawn.py | 4 ++-- .../plugins/training_type/tpu_spawn.py | 8 ++++---- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6ebc320a12e19..e02c9d32ecb80 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1917,7 +1917,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> None: + def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1931,7 +1931,7 @@ def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> None: + def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 89250120a3711..7003c2617037f 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -184,7 +184,7 @@ def _worker_setup(self, process_idx: int): self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size ) - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: # move the model to the correct device self.model_to_device() @@ -244,7 +244,7 @@ def determine_ddp_device_ids(self): def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: + ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: rank_zero_warn("cleaning up ddp environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -262,7 +262,7 @@ def __collect_rank_zero_results( self.checkpoint_io.save_checkpoint(state_dict, last_path) # adds the `callback_metrics` to the queue - extra = _ExtraQueue() + extra = _FakeQueue() if is_overridden("add_to_queue", self.lightning_module): # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) @@ -272,7 +272,7 @@ def __collect_rank_zero_results( return best_model_path, last_path, results, extra def __recover_results_in_main_process( - self, best_path: Optional[str], last_path: Optional[str], extra: "_ExtraQueue", trainer + self, best_path: Optional[str], last_path: Optional[str], extra: "_FakeQueue", trainer ) -> None: # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: @@ -357,7 +357,7 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: + def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -370,7 +370,7 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: ) # send as numpy to avoid issues with memory sharing queue.put(callback_metrics) - def get_from_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: + def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -402,7 +402,7 @@ def teardown(self) -> None: torch.cuda.empty_cache() -class _ExtraQueue(list): +class _FakeQueue(list): """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" def get(self) -> Any: diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 55b9253d3101c..5e10155cc3ca5 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType @@ -114,7 +114,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 1f73595c37418..4dcdb589150ca 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -29,7 +29,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters @@ -164,7 +164,7 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None: def set_world_ranks(self, process_idx: int = 0) -> None: pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: trainer.progress_bar_callback.disable() @@ -204,7 +204,7 @@ def barrier(self, name: Optional[str] = None) -> None: def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: + ) -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: rank_zero_warn("cleaning up tpu spawn environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -223,7 +223,7 @@ def __collect_rank_zero_results( return # adds the `callback_metrics` to the queue - extra = _ExtraQueue() + extra = _FakeQueue() if is_overridden("add_to_queue", self.lightning_module): # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) From 65d183c25113fcc43334865a57c092bbdd3cd841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 21:12:56 +0100 Subject: [PATCH 38/41] missing typing updates --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 7003c2617037f..3e9840e33c4ca 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -139,12 +139,12 @@ def start_training(self, trainer: "pl.Trainer") -> Any: trainer.optimizers = [] return results - def start_evaluating(self, trainer: "pl.Trainer") -> None: + def start_evaluating(self, trainer: "pl.Trainer") -> Any: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results - def start_predicting(self, trainer: "pl.Trainer") -> None: + def start_predicting(self, trainer: "pl.Trainer") -> Any: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4dcdb589150ca..7679fbffa8e50 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -299,18 +299,18 @@ def _worker_setup(self, process_idx: int): self.tpu_global_core_rank = xm.get_ordinal() rank_zero_only.rank = self.global_rank - def start_training(self, trainer: "pl.Trainer") -> None: + def start_training(self, trainer: "pl.Trainer") -> Any: # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] self._clean_logger(trainer) return super().start_training(trainer) - def start_evaluating(self, trainer: "pl.Trainer") -> None: + def start_evaluating(self, trainer: "pl.Trainer") -> Any: self._clean_logger(trainer) return super().start_evaluating(trainer) - def start_predicting(self, trainer: "pl.Trainer") -> None: + def start_predicting(self, trainer: "pl.Trainer") -> Any: self._clean_logger(trainer) return super().start_predicting(trainer) From 8c4e2e49a229794846c015fe414c6e04a4fce8d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 22:25:35 +0100 Subject: [PATCH 39/41] Introducing NamedTuple for spawn output typing --- .../plugins/training_type/ddp_spawn.py | 57 ++++++++++--------- .../plugins/training_type/tpu_spawn.py | 10 ++-- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 3e9840e33c4ca..f42edd0ae7763 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union, NamedTuple import numpy as np import torch @@ -45,7 +45,7 @@ from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT, _PATH if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook @@ -133,23 +133,23 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st return {"nprocs": self.num_processes} def start_training(self, trainer: "pl.Trainer") -> Any: - best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) + spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) + self.__recover_results_in_main_process(spawn_output, trainer) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] - return results + return spawn_output.trainer_results def start_evaluating(self, trainer: "pl.Trainer") -> Any: - best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) - return results + spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) + self.__recover_results_in_main_process(spawn_output, trainer) + return spawn_output.trainer_results def start_predicting(self, trainer: "pl.Trainer") -> Any: - best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) - return results + spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) + self.__recover_results_in_main_process(spawn_output, trainer) + return spawn_output.trainer_results - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: """Spawn processes that run the given function. Args: @@ -184,7 +184,7 @@ def _worker_setup(self, process_idx: int): self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size ) - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: + def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: # move the model to the correct device self.model_to_device() @@ -242,9 +242,7 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __collect_rank_zero_results( - self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: + def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up ddp environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -269,28 +267,28 @@ def __collect_rank_zero_results( else: self.add_to_queue(trainer, extra) - return best_model_path, last_path, results, extra + return _SpawnOutput(best_model_path, last_path, results, extra) - def __recover_results_in_main_process( - self, best_path: Optional[str], last_path: Optional[str], extra: "_FakeQueue", trainer - ) -> None: + def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: - self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also best score + self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path + # TODO: pass also best score # load last weights - if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: - ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage)) + if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + ckpt = self.checkpoint_io.load_checkpoint( + spawn_output.last_path, map_location=(lambda storage, loc: storage) + ) self.lightning_module.load_state_dict(ckpt) # get the `callback_metrics` and set it to the trainer if is_overridden("get_from_queue", self.lightning_module): # only in case the user does not override it. # TODO: Remove the if in v1.7 - self.lightning_module.get_from_queue(extra) + self.lightning_module.get_from_queue(spawn_output.extra) else: - self.get_from_queue(trainer, extra) + self.get_from_queue(trainer, spawn_output.extra) def barrier(self, *args, **kwargs) -> None: if not distributed_available(): @@ -413,3 +411,10 @@ def put(self, item: Any) -> None: def empty(self) -> bool: return len(self) == 0 + + +class _SpawnOutput(NamedTuple): + best_model_path: Optional[_PATH] + last_path: Optional[_PATH] + trainer_results: Any + extra: _FakeQueue diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7679fbffa8e50..73b6e9f8a39b9 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -29,7 +29,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin, _SpawnOutput from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters @@ -202,9 +202,7 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __collect_rank_zero_results( - self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: + def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up tpu spawn environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -230,7 +228,7 @@ def __collect_rank_zero_results( else: self.add_to_queue(trainer, extra) - return best_model_path, last_path, results, extra + return _SpawnOutput(best_model_path, last_path, results, extra) def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: @@ -274,7 +272,7 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: context = mp.get_context(self.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) From 213b447278153e2eb52badd535933bcb551d5e2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Dec 2021 21:27:59 +0000 Subject: [PATCH 40/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index f42edd0ae7763..7620329b60e7b 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union, NamedTuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union import numpy as np import torch @@ -45,7 +45,7 @@ from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import STEP_OUTPUT, _PATH +from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 73b6e9f8a39b9..b0284c88d6566 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -29,7 +29,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin, _SpawnOutput +from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters From 257924726ae244564e81e340414eaf86a945bea2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 10:17:14 +0100 Subject: [PATCH 41/41] inherit from UserList --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 7620329b60e7b..563f39a1f0cf4 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -14,6 +14,7 @@ import logging import os import re +from collections import UserList from multiprocessing.queues import SimpleQueue from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union @@ -400,7 +401,7 @@ def teardown(self) -> None: torch.cuda.empty_cache() -class _FakeQueue(list): +class _FakeQueue(UserList): """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" def get(self) -> Any: