Commit fdbbc08

Update TPU to share the same logic with TTP
1 parent: 38ed26e

2 files changed, +1 -20 lines

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 0 additions & 11 deletions
@@ -14,15 +14,12 @@
 import os
 from typing import Any, Dict, Optional
 
-import torch
-
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _PATH
@@ -66,14 +63,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
 
-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-
     def model_to_device(self) -> None:
         self.model.to(self.root_device)
 
pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 1 addition & 9 deletions
@@ -32,7 +32,7 @@
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -127,14 +127,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
 
-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-
     def _setup_model(self, model: Module) -> Module:
         return model
 
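Both deletions remove the same per-plugin override of `_move_optimizer_state`, so the single-device and spawn TPU plugins now rely on the shared training-type plugin (TTP) base class for this step. Below is a minimal, standalone sketch of the logic those overrides implemented, using only the `apply_to_collection` and `move_data_to_device` utilities these files already imported; the helper name `move_optimizer_state_to` is illustrative and not part of this commit:

from typing import Iterable

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


def move_optimizer_state_to(optimizers: Iterable[torch.optim.Optimizer], device: torch.device) -> None:
    """Move every tensor held in the optimizers' state to ``device``.

    Mirrors the body of the overrides deleted above.
    """
    for opt in optimizers:
        for p, v in opt.state.items():
            # apply_to_collection walks the (possibly nested) state entry and applies
            # move_data_to_device to each torch.Tensor, leaving non-tensor values untouched.
            opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)

For example, after one `optimizer.step()` on an Adam optimizer, `move_optimizer_state_to([optimizer], torch.device("cpu"))` relocates the `exp_avg`/`exp_avg_sq` buffers without touching the parameter keys.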