diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7e29976caee57..77e5166f6ad59 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
 
+- Removed unnecessary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))
+
+
 - Removed `model_sharded_context` method from `Accelerator` ([#10886](https://github.com/PyTorchLightning/pytorch-lightning/pull/10886))
 
 
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index f9fa415e67090..11f0c60954720 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -14,15 +14,12 @@
 import os
 from typing import Any, Dict, Optional
 
-import torch
-
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _PATH
@@ -66,14 +63,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
 
-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-
     def model_to_device(self) -> None:
         self.model.to(self.root_device)
 
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 6f332ee4fe7fe..79f4797eda297 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -33,7 +33,7 @@
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -60,7 +60,7 @@ def __init__(
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         debug: bool = False,
-        **_: Any
+        **_: Any,
     ) -> None:
         checkpoint_io = checkpoint_io or XLACheckpointIO()
         super().__init__(
@@ -128,14 +128,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
 
-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-
     def _setup_model(self, model: Module) -> Module:
         return model
 
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 05d444a849df3..9334df4b18b60 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -107,11 +107,12 @@ def setup_precision_plugin(self) -> None:
         self.lr_schedulers = schedulers
 
     def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the GPU if needed."""
-        device = device or self.root_device
+        """Moves the state of the optimizers to the appropriate device if needed."""
         for opt in self.optimizers:
             for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
+                # `self.root_device` would raise an error if called outside the spawn process
+                # while training on 8 or more cores.
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)
 
     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """Returns state of an optimizer.
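
The key idea behind dropping the TPU overrides is that the base `_move_optimizer_state` now resolves `self.root_device` lazily, inside the loop, so the property is never touched when an optimizer has no state to move, which is the situation outside the XLA spawn processes. A minimal standalone sketch of that pattern is below; it is not the Lightning source: `_SketchPlugin` and `_unavailable_device` are hypothetical stand-ins, and a plain dict comprehension replaces `apply_to_collection`.

```python
from typing import Optional

import torch
from torch.optim import SGD


class _SketchPlugin:
    """Hypothetical stand-in for a training-type plugin, reduced to the relevant bits."""

    def __init__(self, optimizers, root_device_fn):
        self.optimizers = optimizers
        self._root_device_fn = root_device_fn  # may raise outside the spawn process

    @property
    def root_device(self) -> torch.device:
        return self._root_device_fn()

    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
        for opt in self.optimizers:
            for p, v in opt.state.items():
                # `device or self.root_device` is evaluated only when there is state to
                # move; a dict comprehension stands in for `apply_to_collection` here.
                opt.state[p] = {
                    k: t.to(device or self.root_device) if isinstance(t, torch.Tensor) else t
                    for k, t in v.items()
                }


def _unavailable_device() -> torch.device:
    raise RuntimeError("root_device is only available inside the spawn process")


param = torch.nn.Parameter(torch.zeros(1))
opt = SGD([param], lr=0.1)          # no optimizer state exists until `step()` has run
plugin = _SketchPlugin([opt], _unavailable_device)
plugin._move_optimizer_state()      # succeeds: the empty state never touches `root_device`
```

Inside the spawn processes, where `root_device` is valid and optimizer state exists, the lazy lookup resolves to the same device as before, so behaviour there is unchanged.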