From 123d7344fa0ea29e9c0166709bbc1dda2ff4fe95 Mon Sep 17 00:00:00 2001
From: Siyu Wang
Date: Tue, 30 Nov 2021 10:11:32 -0800
Subject: [PATCH 1/9] Update TPU to share the same logic with TTP

---
 pytorch_lightning/plugins/training_type/single_tpu.py | 11 -----------
 pytorch_lightning/plugins/training_type/tpu_spawn.py  | 10 +---------
 2 files changed, 1 insertion(+), 20 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index f9fa415e67090..11f0c60954720 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -14,15 +14,12 @@
 import os
 from typing import Any, Dict, Optional

-import torch
-
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _PATH
@@ -66,14 +63,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()

-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-
     def model_to_device(self) -> None:
         self.model.to(self.root_device)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 6f332ee4fe7fe..7eae2c60ac833 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -33,7 +33,7 @@
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -128,14 +128,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()

-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-
     def _setup_model(self, model: Module) -> Module:
         return model


From 0fb7d21ff96fa9f0a5c640e92e0f182e401a2655 Mon Sep 17 00:00:00 2001
From: Siyu Wang
Date: Tue, 30 Nov 2021 17:52:22 -0800
Subject: [PATCH 2/9] run test

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 7eae2c60ac833..be6bc73f050fe 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -33,7 +33,7 @@
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -128,6 +128,14 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()

+    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
+        """Moves the state of the optimizers to the TPU if needed."""
+        # TODO: `self.root_device` would raise error if called outside the spawn process
+        # while training on 8 and more cores.
+        for opt in self.optimizers:
+            for p, v in opt.state.items():
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, xm.xla_device())
+
     def _setup_model(self, model: Module) -> Module:
         return model


From 8c067cd1f98a3eda5be8dd496030013c6ea0e145 Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Tue, 30 Nov 2021 21:54:12 -0800
Subject: [PATCH 3/9] Update tpu_spawn.py

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index be6bc73f050fe..7221f2a575eb5 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -132,9 +132,10 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the TPU if needed."""
         # TODO: `self.root_device` would raise error if called outside the spawn process
         # while training on 8 and more cores.
+        device = device or self.root_device
         for opt in self.optimizers:
             for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, xm.xla_device())
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)

     def _setup_model(self, model: Module) -> Module:
         return model

From 4c8ba27fe0c29eea7e5766a1ec2a9ff53008830c Mon Sep 17 00:00:00 2001
From: Siyu Wang
Date: Wed, 1 Dec 2021 11:04:50 -0800
Subject: [PATCH 4/9] debug

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 7221f2a575eb5..66b5c3234fa04 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -60,7 +60,7 @@ def __init__(
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         debug: bool = False,
-        **_: Any
+        **_: Any,
     ) -> None:
         checkpoint_io = checkpoint_io or XLACheckpointIO()
         super().__init__(
@@ -132,6 +132,8 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the TPU if needed."""
         # TODO: `self.root_device` would raise error if called outside the spawn process
         # while training on 8 and more cores.
+        if device:
+            raise ValueError(f"device should be None" f" found: {device}.")
         device = device or self.root_device
         for opt in self.optimizers:
             for p, v in opt.state.items():

From d70cac3b692762a8873a8851b17dad2d2a036837 Mon Sep 17 00:00:00 2001
From: Siyu Wang
Date: Wed, 1 Dec 2021 11:55:49 -0800
Subject: [PATCH 5/9] Add changelog

---
 CHANGELOG.md                                       |  3 +++
 .../plugins/training_type/tpu_spawn.py             | 13 +------------
 .../plugins/training_type/training_type_plugin.py  |  3 +--
 3 files changed, 5 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 629b28e392792..3fb0ed21832b5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))

+- Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))
+
+
 ### Fixed

 - Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 66b5c3234fa04..79f4797eda297 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -33,7 +33,7 @@
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -128,17 +128,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()

-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        if device:
-            raise ValueError(f"device should be None" f" found: {device}.")
-        device = device or self.root_device
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
-
     def _setup_model(self, model: Module) -> Module:
         return model

diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 05d444a849df3..6ca5e83a43309 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -108,10 +108,9 @@ def setup_precision_plugin(self) -> None:

     def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the GPU if needed."""
-        device = device or self.root_device
         for opt in self.optimizers:
             for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)

     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """Returns state of an optimizer.
From a0da347f9467cc2bdcd0e873b5c00bb7bac40ba8 Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Wed, 1 Dec 2021 14:55:34 -0800
Subject: [PATCH 6/9] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli

---
 pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 6ca5e83a43309..94cc9c66d1fb4 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -107,7 +107,7 @@ def setup_precision_plugin(self) -> None:
         self.lr_schedulers = schedulers

     def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the GPU if needed."""
+        """Moves the state of the optimizers to the appropriate device if needed."""
         for opt in self.optimizers:
             for p, v in opt.state.items():
                 opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)

From 327564f53d7d21f5784f9650020fcd304d468497 Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Wed, 1 Dec 2021 15:04:26 -0800
Subject: [PATCH 7/9] Update training_type_plugin.py

---
 pytorch_lightning/plugins/training_type/training_type_plugin.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 94cc9c66d1fb4..422ea7ebfa3fb 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -110,6 +110,7 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the appropriate device if needed."""
         for opt in self.optimizers:
             for p, v in opt.state.items():
+                #`self.root_device` would raise error if called outside the spawn process while training on 8 and more cores.
                 opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)

     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:

From 91d00430a0ea938dacec857ed0b4afc04f1920a0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 1 Dec 2021 23:06:00 +0000
Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 422ea7ebfa3fb..3eb282f1b4aa8 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -110,7 +110,7 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the appropriate device if needed."""
         for opt in self.optimizers:
             for p, v in opt.state.items():
-                #`self.root_device` would raise error if called outside the spawn process while training on 8 and more cores.
+                # `self.root_device` would raise error if called outside the spawn process while training on 8 and more cores.
                 opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)

     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:

From 7515e234d2569a58387686f5f39bfa73f7540c4c Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Wed, 1 Dec 2021 15:41:34 -0800
Subject: [PATCH 9/9] Update training_type_plugin.py

---
 .../plugins/training_type/training_type_plugin.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 3eb282f1b4aa8..9334df4b18b60 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -110,7 +110,8 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the appropriate device if needed."""
         for opt in self.optimizers:
             for p, v in opt.state.items():
-                # `self.root_device` would raise error if called outside the spawn process while training on 8 and more cores.
+                # `self.root_device` would raise error if called outside the spawn process
+                # while training on 8 and more cores.
                 opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)

     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: