 import os
 from typing import Any, Dict, Optional
 
-import torch
-
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _PATH
@@ -66,14 +63,6 @@ def setup(self, trainer: "pl.Trainer") -> None: |
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
 
-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-
     def model_to_device(self) -> None:
         self.model.to(self.root_device)
 
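For reference, the removed `_move_optimizer_state` hook used `apply_to_collection` and `move_data_to_device` to push every tensor held in an optimizer's state onto the plugin's `root_device`. The snippet below is a minimal, standalone sketch of that pattern in plain PyTorch; the helper name `move_optimizer_state_to` is illustrative rather than part of the Lightning API, and it assumes each per-parameter state entry is a flat dict of tensors (as with SGD's momentum buffer).

# Illustrative helper (not part of pytorch_lightning): moves every tensor stored
# in ``optimizer.state`` to the given device, mirroring what the removed
# ``_move_optimizer_state`` did via ``apply_to_collection``/``move_data_to_device``.
import torch


def move_optimizer_state_to(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    for param, state in optimizer.state.items():
        # Assumes each state entry is a flat dict, e.g. SGD's {"momentum_buffer": Tensor}.
        optimizer.state[param] = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in state.items()
        }


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model(torch.randn(8, 4)).sum().backward()
    opt.step()  # populates the momentum buffers in ``opt.state``
    move_optimizer_state_to(opt, torch.device("cpu"))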