 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -127,6 +127,14 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()

+    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
+        """Moves the state of the optimizers to the TPU if needed."""
+        # TODO: `self.root_device` would raise an error if called outside the spawn process
+        # while training on 8 or more cores.
+        for opt in self.optimizers:
+            for p, v in opt.state.items():
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
+
     def _setup_model(self, model: Module) -> Module:
         return model

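For context, the new `_move_optimizer_state` hook walks each optimizer's `state` mapping (parameter -> per-parameter state dict, e.g. Adam's `exp_avg`/`exp_avg_sq` buffers) and moves every tensor it finds to the plugin's `root_device`. Below is a minimal standalone sketch of the same pattern in plain PyTorch; `target_device` is a hypothetical stand-in for `self.root_device`, and the sketch illustrates the idea rather than the Lightning implementation.

import torch

# Hypothetical target device for this sketch. On a TPU host this would be an
# XLA device (via torch_xla); CPU keeps the example runnable anywhere.
target_device = torch.device("cpu")

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

# Take one optimizer step so per-parameter state (exp_avg, exp_avg_sq, ...) exists.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

# Rough equivalent of `_move_optimizer_state`: move every tensor in the
# per-parameter state dicts to the target device, leaving non-tensor entries
# (e.g. an integer `step` counter in some optimizer versions) untouched.
# The Lightning hook uses `apply_to_collection`, which also recurses into
# nested collections; this sketch assumes flat state dicts.
for param, state in optimizer.state.items():
    optimizer.state[param] = {
        k: v.to(target_device) if isinstance(v, torch.Tensor) else v
        for k, v in state.items()
    }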
|
|