33 | 33 | from pytorch_lightning.trainer.connectors.data_connector import DataConnector
34 | 34 | from pytorch_lightning.trainer.states import TrainerFn
35 | 35 | from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters
36 | | -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
| 36 | +from pytorch_lightning.utilities.apply_func import move_data_to_device
37 | 37 | from pytorch_lightning.utilities.data import has_len
38 | 38 | from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
39 | 39 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -128,17 +128,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
128 | 128 |         self.setup_optimizers(trainer)
129 | 129 |         self.setup_precision_plugin()
130 | 130 |
131 | | -    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
132 | | -        """Moves the state of the optimizers to the TPU if needed."""
133 | | -        # TODO: `self.root_device` would raise error if called outside the spawn process
134 | | -        # while training on 8 and more cores.
135 | | -        if device:
136 | | -            raise ValueError(f"device should be None" f" found: {device}.")
137 | | -        device = device or self.root_device
138 | | -        for opt in self.optimizers:
139 | | -            for p, v in opt.state.items():
140 | | -                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
141 | | -
142 | 131 |     def _setup_model(self, model: Module) -> Module:
143 | 132 |         return model
144 | 133 |
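
For context, the deleted override follows the generic "move optimizer state to a device" pattern: walk `optimizer.state` and transfer every tensor it holds. Below is a minimal standalone sketch of that pattern in plain PyTorch, assuming the TPU-specific override is redundant because an equivalent device-agnostic version lives elsewhere in the strategy hierarchy; the helper name `move_optimizer_state` and the usage example are illustrative and are not code from this PR or from Lightning.

```python
import torch


def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    """Move every tensor stored in ``optimizer.state`` to ``device`` in place."""
    for param, state in optimizer.state.items():
        # Each entry is a dict such as {"step": ..., "exp_avg": Tensor, "exp_avg_sq": Tensor};
        # only tensor values are transferred, everything else is kept as-is.
        optimizer.state[param] = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in state.items()
        }


# Usage sketch: populate Adam state with one step, then relocate it.
model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(8, 4)).sum().backward()
opt.step()  # creates exp_avg / exp_avg_sq tensors in opt.state
move_optimizer_state(opt, torch.device("cpu"))
```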