
Commit d70cac3

Add changelog
1 parent 4c8ba27 commit d70cac3

3 files changed: 5 additions & 14 deletions


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
 
 
+- Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))
+
+
 ### Fixed
 
 - Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
```

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 1 addition & 12 deletions
```diff
@@ -33,7 +33,7 @@
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -128,17 +128,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
 
-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        if device:
-            raise ValueError(f"device should be None" f" found: {device}.")
-        device = device or self.root_device
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
-
     def _setup_model(self, model: Module) -> Module:
         return model
```
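
The override removed above appears to have existed only because of the TODO it carried: on TPU, resolving `self.root_device` (ultimately `xm.xla_device()`) outside the spawned worker process can raise when training on 8 or more cores. Since the base class now resolves the device inline, and only when there is actually optimizer state to move (see the next file), the duplicated bookkeeping is no longer needed. Below is a minimal sketch of the underlying pitfall, assuming `torch_xla` is available; the worker function and process count are illustrative, not part of this commit.

```python
# Sketch only: device resolution must happen inside the spawned worker on
# multi-core TPU, which is the situation the removed TODO described.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker_fn(rank: int) -> None:
    # Safe: xm.xla_device() is called inside the spawned worker process.
    device = xm.xla_device()
    print(f"rank {rank} running on {device}")


if __name__ == "__main__":
    # Calling xm.xla_device() out here, before spawn, is the unsafe pattern
    # that the removed override's TODO warned about.
    xmp.spawn(_worker_fn, nprocs=8)
```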

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -108,10 +108,9 @@ def setup_precision_plugin(self) -> None:
 
     def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the GPU if needed."""
-        device = device or self.root_device
         for opt in self.optimizers:
             for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)
 
     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """Returns state of an optimizer.
```
