Skip to content

Commit 8f85576

Browse files
committed
temp change to run tpu test
1 parent fdbbc08 commit 8f85576

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
3333
from pytorch_lightning.trainer.states import TrainerFn
3434
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
35-
from pytorch_lightning.utilities.apply_func import move_data_to_device
35+
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
3636
from pytorch_lightning.utilities.data import has_len
3737
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
3838
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -127,6 +127,14 @@ def setup(self, trainer: "pl.Trainer") -> None:
127127
self.setup_optimizers(trainer)
128128
self.setup_precision_plugin()
129129

130+
def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
    """Relocate all optimizer state tensors onto the TPU device.

    The ``device`` argument is accepted for interface compatibility but is
    not consulted here; the target is always ``self.root_device``.
    """
    # TODO: `self.root_device` would raise error if called outside the spawn process
    # while training on 8 and more cores.
    target = self.root_device
    for optimizer in self.optimizers:
        state = optimizer.state
        for param_key in list(state):
            # Move every tensor inside this parameter's state entry to the TPU.
            state[param_key] = apply_to_collection(
                state[param_key], torch.Tensor, move_data_to_device, target
            )
137+
130138
def _setup_model(self, model: Module) -> Module:
131139
return model
132140

0 commit comments

Comments
 (0)