
Commit 1d28785

four4fish, daniellepintz, carmocca, tchaton, and awaelchli authored
2/n Move Precision Plugin into strategy - move optimizer related logics (#10596)
Co-authored-by: Danielle Pintz <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent ce95891 commit 1d28785

28 files changed, +227 −215 lines changed
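The through-line of this commit, visible in each per-file diff below, is that optimizer-related hooks leave the `Accelerator` and become the responsibility of the `TrainingTypePlugin` (the strategy), with internal call sites switching from `trainer.accelerator` to `trainer.training_type_plugin`. A minimal sketch of that caller-side migration — the wrapper function and its parameters are hypothetical, only the two attribute paths and hook names come from the diffs below:

import pytorch_lightning as pl
import torch


def backward_and_step(trainer: "pl.Trainer", loss: torch.Tensor, optimizer, opt_idx: int, closure) -> None:
    # Hypothetical helper, for illustration only.
    # Before #10596 these internal call sites went through the accelerator:
    #     trainer.accelerator.backward(loss, None, None)
    #     trainer.accelerator.optimizer_step(optimizer, opt_idx, closure)
    # After #10596 the same hooks live on (and are reached via) the strategy:
    trainer.training_type_plugin.backward(loss, None, None)
    trainer.training_type_plugin.optimizer_step(optimizer, opt_idx, closure)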

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541))


+- Moved optimizer related logics from `Accelerator` to `TrainingTypePlugin` ([#10596](https://github.com/PyTorchLightning/pytorch-lightning/pull/10596))
+
+
 - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))

pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 129 deletions

@@ -13,21 +13,14 @@
 # limitations under the License.
 import contextlib
 from abc import abstractmethod
-from typing import Any, Callable, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, Optional, Union

 import torch
-from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn import Module
-from torch.optim import Optimizer

 import pytorch_lightning as pl
-from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
+from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
-from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import rank_zero_deprecation
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
-from pytorch_lightning.utilities.enums import AMPType, LightningEnum
 from pytorch_lightning.utilities.types import STEP_OUTPUT


@@ -62,10 +55,6 @@ def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_pl
         if precision_plugin is not None:
             self.training_type_plugin._precision_plugin = precision_plugin

-        self.optimizers: List = []
-        self.lr_schedulers: List = []
-        self.optimizer_frequencies: List = []
-
     def setup_environment(self) -> None:
         """Setup any processes or distributed connections.

@@ -80,28 +69,18 @@ def setup(self, trainer: "pl.Trainer") -> None:
         Args:
             trainer: the trainer instance
         """
-        self.setup_training_type_plugin()
-        if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
-            self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
+        self.training_type_plugin.setup(trainer)

     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
-        self._move_optimizer_state()
+        self.training_type_plugin._move_optimizer_state()

         self.training_type_plugin.pre_dispatch()
         if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
-            self.setup_optimizers(trainer)
+            self.training_type_plugin.setup_optimizers(trainer)

         self.training_type_plugin.precision_plugin.pre_dispatch()

-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the GPU if needed."""
-        device = device or self.root_device
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
-
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.dispatch(trainer)
@@ -177,115 +156,12 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         with self.training_type_plugin.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())

-    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
-        """Forwards backward-calls to the precision plugin.
-
-        Args:
-            closure_loss: a tensor holding the loss value to backpropagate
-        """
-        self.training_type_plugin.pre_backward(closure_loss)
-        closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss)
-
-        self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
-
-        closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss)
-        self.training_type_plugin.post_backward(closure_loss)
-
-        return closure_loss
-
-    def optimizer_step(
-        self,
-        optimizer: Optimizer,
-        opt_idx: int,
-        closure: Callable[[], Any],
-        model: Optional[Union["pl.LightningModule", Module]] = None,
-        **kwargs: Any,
-    ) -> None:
-        """performs the actual optimizer step.
-
-        Args:
-            optimizer: the optimizer performing the step
-            opt_idx: index of the current optimizer
-            closure: closure calculating the loss value
-            model: reference to the model, optionally defining optimizer step related hooks
-            **kwargs: Any extra arguments to ``optimizer.step``
-        """
-        model = model or self.lightning_module
-        self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
-
-    def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
-        """Zeros all model parameter's gradients."""
-        model_ref = self.lightning_module
-        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
-
-    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
-        """Creates optimizers and schedulers.
-
-        Args:
-            trainer: the Trainer, these optimizers should be connected to
-        """
-        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
-            return
-        optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
-            trainer=trainer, model=self.lightning_module
-        )
-        self.optimizers = optimizers
-        self.lr_schedulers = lr_schedulers
-        self.optimizer_frequencies = optimizer_frequencies
-
-    def setup_training_type_plugin(self) -> None:
-        """Attaches the training type plugin to the accelerator."""
-        self.training_type_plugin.setup()
-
-    def setup_precision_plugin(self) -> None:
-        """Attaches the precision plugin to the accelerator."""
-        model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect(
-            self.model, self.optimizers, self.lr_schedulers
-        )
-        self.model = model
-        self.optimizers = optimizers
-        self.lr_schedulers = schedulers
-
-    @property
-    def amp_backend(self) -> Optional[LightningEnum]:
-        if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin):
-            return AMPType.APEX
-        if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin):
-            return AMPType.NATIVE
-        return None
-
-    @property
-    def precision(self) -> Union[str, int]:
-        """The type of precision being used with this accelerator.
-
-        .. deprecated::
-            This property been deprecated and will be removed soon.
-            Use ``training_type_plugin.precision_plugin.precision`` instead.
-        """
-        rank_zero_deprecation(
-            f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon"
-            f" Use `training_type_plugin.precision_plugin.precision` instead."
-        )
-        return self.training_type_plugin.precision_plugin.precision
-
-    @property
-    def scaler(self) -> Optional["GradScaler"]:
-        return getattr(self.training_type_plugin.precision_plugin, "scaler", None)
-
-    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
-        """Returns state of an optimizer.
-
-        Allows for syncing/collating optimizer state from processes in custom plugins.
-        """
-        return getattr(self.training_type_plugin, "optimizer_state", lambda x: x.state_dict())(optimizer)
-
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator[None, None, None]:
         """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to.

         shard the model instantly - useful for extremely large models. Can save memory and
         initialization time.
-
         Returns:
             Model parallel context.
         """

pytorch_lightning/accelerators/cpu.py

Lines changed: 4 additions & 2 deletions

@@ -29,8 +29,10 @@ def setup(self, trainer: "pl.Trainer") -> None:
         MisconfigurationException:
             If the selected device is not CPU.
         """
-        if "cpu" not in str(self.root_device):
-            raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")
+        if "cpu" not in str(self.training_type_plugin.root_device):
+            raise MisconfigurationException(
+                f"Device should be CPU, got {self.training_type_plugin.root_device} instead."
+            )

         return super().setup(trainer)

pytorch_lightning/accelerators/gpu.py

Lines changed: 6 additions & 4 deletions

@@ -37,9 +37,11 @@ def setup_environment(self) -> None:
             If the selected device is not GPU.
         """
         super().setup_environment()
-        if "cuda" not in str(self.root_device):
-            raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
-        torch.cuda.set_device(self.root_device)
+        if "cuda" not in str(self.training_type_plugin.root_device):
+            raise MisconfigurationException(
+                f"Device should be GPU, got {self.training_type_plugin.root_device} instead"
+            )
+        torch.cuda.set_device(self.training_type_plugin.root_device)

     def setup(self, trainer: "pl.Trainer") -> None:
         self.set_nvidia_flags(trainer.local_rank)
@@ -77,7 +79,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:

     def teardown(self) -> None:
         super().teardown()
-        self._move_optimizer_state(torch.device("cpu"))
+        self.training_type_plugin._move_optimizer_state(torch.device("cpu"))

     @staticmethod
     def auto_device_count() -> int:
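`teardown` now asks the strategy to move optimizer state back to the CPU so GPU memory is released when training ends. A plain-PyTorch sketch of what that offloading amounts to (the real hook lives on the strategy and uses Lightning's `apply_to_collection`, as shown in the code removed from `accelerator.py` above):

import torch


def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    # Move every tensor held in the optimizer state (e.g. Adam's exp_avg / exp_avg_sq
    # buffers) to `device`; non-tensor entries such as step counters are left untouched.
    for param, state in optimizer.state.items():
        optimizer.state[param] = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in state.items()
        }


# Example: free accelerator memory held by optimizer buffers after training.
model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
opt.step()  # populates opt.state with moment buffers
move_optimizer_state(opt, torch.device("cpu"))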

pytorch_lightning/accelerators/ipu.py

Lines changed: 0 additions & 13 deletions

@@ -15,25 +15,12 @@

 import torch

-import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class IPUAccelerator(Accelerator):
     """Accelerator for IPUs."""

-    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
-        """
-        Raises:
-            MisconfigurationException:
-                If multiple optimizers are provided.
-        """
-        super().setup_optimizers(trainer)
-
-        if len(self.optimizers) > 1:
-            raise MisconfigurationException("IPUs currently only support one optimizer.")
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """IPU device stats aren't supported yet."""
         return {}
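The deleted single-optimizer guard is not lost either; following the pattern of the commit, it belongs with the strategy's `setup_optimizers`. A hedged sketch of what the relocated check plausibly looks like — the class below is a hypothetical stand-in, since the IPU strategy file is not shown in this excerpt:

from typing import List

import pytorch_lightning as pl
from torch.optim import Optimizer

from pytorch_lightning.utilities.exceptions import MisconfigurationException


class IPUStrategySketch:
    """Hypothetical stand-in for the IPU training type plugin."""

    optimizers: List[Optimizer] = []

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        # Mirrors the guard removed from IPUAccelerator.setup_optimizers above;
        # the base strategy is assumed to populate ``self.optimizers`` first.
        if len(self.optimizers) > 1:
            raise MisconfigurationException("IPUs currently only support one optimizer.")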

pytorch_lightning/accelerators/tpu.py

Lines changed: 1 addition & 10 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Union

 import torch

@@ -21,7 +21,6 @@
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.utilities import _XLA_AVAILABLE
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -49,14 +48,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         )
         return super().setup(trainer)

-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the TPU if needed."""
-        # TODO: `self.root_device` would raise error if called outside the spawn process
-        # while training on 8 and more cores.
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for the given TPU device.

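The TPU-specific override of `_move_optimizer_state` also leaves the accelerator; its one quirk — ignoring the requested device and resolving `root_device` inside the spawned process — would now live on the TPU strategy. A short, hedged sketch (stand-in class name; the destination file is not shown here):

from typing import List, Optional

import torch
from torch.optim import Optimizer

from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


class TPUStrategySketch:
    """Hypothetical stand-in for the single-TPU / TPU-spawn plugins."""

    optimizers: List[Optimizer] = []
    root_device: torch.device

    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
        # The ``device`` argument is intentionally ignored: the XLA device is only valid
        # inside the spawned process, so the strategy resolves ``self.root_device`` itself.
        for opt in self.optimizers:
            for p, v in opt.state.items():
                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
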
pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion

@@ -1347,7 +1347,7 @@ def training_step(...):
             **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
         """
         self._verify_is_manual_optimization("manual_backward")
-        self.trainer.accelerator.backward(loss, None, None, *args, **kwargs)
+        self.trainer.training_type_plugin.backward(loss, None, None, *args, **kwargs)

     def backward(
         self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs

pytorch_lightning/core/optimizer.py

Lines changed: 1 addition & 1 deletion

@@ -161,4 +161,4 @@ def closure_dis():
         trainer = self._trainer
         assert trainer is not None
         with trainer.profiler.profile(profiler_action):
-            trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
+            trainer.training_type_plugin.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
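For users, nothing changes in how manual optimization is written; `manual_backward` and `LightningOptimizer.step` simply route through `trainer.training_type_plugin` instead of `trainer.accelerator`. A small self-contained example of that unchanged user-facing API (the toy model and data are illustrative only):

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ManualOptimModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # LightningOptimizer wrapper
        opt.zero_grad()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.manual_backward(loss)  # -> trainer.training_type_plugin.backward(...)
        opt.step()  # -> trainer.training_type_plugin.optimizer_step(...)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(ManualOptimModel(), data)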

pytorch_lightning/lite/lite.py

Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ def device(self) -> torch.device:

         Use this to create tensors directly on the device if needed.
         """
-        return self._accelerator.root_device
+        return self._strategy.root_device

     @property
     def global_rank(self) -> int:

pytorch_lightning/lite/wrappers.py

Lines changed: 4 additions & 2 deletions

@@ -46,6 +46,8 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
         self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._accelerator = accelerator
+        # TODO (@awaelchli) refactor to take Strategy as param
+        self._strategy = self._accelerator.training_type_plugin

     @property
     def optimizer(self) -> Optimizer:
@@ -56,11 +58,11 @@ def state_dict(self) -> Dict[str, Tensor]:

     def step(self, closure: Optional[Callable] = None) -> None:
         closure = closure or _do_nothing_closure
-        self._accelerator.optimizer_step(
+        self._strategy.optimizer_step(
             self.optimizer,
             opt_idx=0,
             closure=closure,
-            model=self._accelerator.model,
+            model=self._strategy.model,
         )
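`_LiteOptimizer.step` now delegates to the strategy it pulls off the accelerator (the TODO marks the constructor for a later cleanup to accept the strategy directly). The LightningLite user flow is unchanged; a brief sketch, with the toy model chosen purely for illustration:

import torch

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        model = torch.nn.Linear(8, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model, optimizer = self.setup(model, optimizer)  # wraps the optimizer in _LiteOptimizer

        x, y = torch.randn(16, 8), torch.randn(16, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        self.backward(loss)
        optimizer.step()  # -> strategy.optimizer_step(..., model=strategy.model) under the hood


Lite(accelerator="cpu").run()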
