
Commit cec2d79

Authored by four4fish, awaelchli, pre-commit-ci[bot], and carmocca
3/n Move accelerator into Strategy (#11022)
* remove training_step() from accelerator
* remove test, val, predict step
* move
* wip
* accelerator references
* cpu training
* rename occurrences in tests
* update tests
* pull from adrian's commit
* fix changelog merge pro
* fix accelerator_connector and other updates
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix doc build and some mypy
* fix lite
* fix gpu setup environment
* support customized ttp and accelerator
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix tpu error check
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix precision_plugin initialization to recognize customized plugin
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update bug_report_model.py
* Update accelerator_connector.py
* update changelog
* allow shorthand typing references to pl.Accelerator
* rename helper method and add docstring
* fix typing
* Update pytorch_lightning/trainer/connectors/accelerator_connector.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* Update tests/accelerators/test_accelerator_connector.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* Update tests/accelerators/test_cpu.py (Co-authored-by: Carlos Mocholí <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix pre-commit complaint
* update typing to long ugly path
* spacing in flow diagram
* remove todo comments
* docformatter
* Update pytorch_lightning/plugins/training_type/training_type_plugin.py
* revert test changes
* improve custom plugin examples
* remove redundant call to ttp attribute; it is no longer a property
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Apply suggestions from code review (Co-authored-by: Carlos Mocholí <[email protected]>)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 9e56290 commit cec2d79

32 files changed (+200 / -229 lines)
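
In practice, the refactor described above changes the user-facing wiring as follows (a minimal sketch adapted from the documentation updates in this commit; the classes are existing Lightning classes, but treat the exact arguments as illustrative for this development version):

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import GPUAccelerator
    from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin

    # The training type plugin (strategy) now owns the accelerator and the precision plugin,
    # instead of the accelerator owning the training type plugin.
    accelerator = GPUAccelerator()
    precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
    strategy = DDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin)

    trainer = Trainer(strategy=strategy)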

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -114,6 +114,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))
 
 
+- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022))
+
 
 ### Deprecated

docs/source/extensions/accelerators.rst

Lines changed: 4 additions & 5 deletions
@@ -25,11 +25,10 @@ One to handle differences from the training routine and one to handle different
     from pytorch_lightning.accelerators import GPUAccelerator
     from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin
 
-    accelerator = GPUAccelerator(
-        precision_plugin=NativeMixedPrecisionPlugin(precision=16, device="cuda"),
-        training_type_plugin=DDPPlugin(),
-    )
-    trainer = Trainer(accelerator=accelerator)
+    accelerator = GPUAccelerator()
+    precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
+    training_type_plugin = DDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin)
+    trainer = Trainer(strategy=training_type_plugin)
 
 
 We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new

docs/source/extensions/plugins.rst

Lines changed: 4 additions & 5 deletions
@@ -80,11 +80,10 @@ can then be passed into the Trainer directly or via a (custom) accelerator:
     trainer = Trainer(strategy=CustomDDPPlugin(), plugins=[CustomPrecisionPlugin()])
 
     # fully custom accelerator and plugins
-    accelerator = MyAccelerator(
-        precision_plugin=CustomPrecisionPlugin(),
-        training_type_plugin=CustomDDPPlugin(),
-    )
-    trainer = Trainer(accelerator=accelerator)
+    accelerator = MyAccelerator()
+    precision_plugin = MyPrecisionPlugin()
+    training_type_plugin = CustomDDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin)
+    trainer = Trainer(strategy=training_type_plugin)
 
 
 The full list of built-in plugins is listed below.

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 60 deletions
@@ -12,14 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import abstractmethod
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Union
 
 import torch
-from torch.nn import Module
 
 import pytorch_lightning as pl
-from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 
 
 class Accelerator:
@@ -31,76 +28,21 @@ class Accelerator:
     - GPU
     - TPU
     - IPU
-
-    Each Accelerator gets two plugins upon initialization:
-    One to handle differences from the training routine and one to handle different precisions.
     """
 
-    def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None:
-        """
-        Args:
-            precision_plugin: the plugin to handle precision-specific parts
-
-                .. deprecated::
-                    The ``precision_plugin`` parameter has been deprecated and will be removed soon.
-                    Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead.
-
-            training_type_plugin: the plugin to handle different training routines
-        """
-
-        self.training_type_plugin = training_type_plugin
-
-        if precision_plugin is not None:
-            self.training_type_plugin._precision_plugin = precision_plugin
-
-    def setup_environment(self) -> None:
+    def setup_environment(self, root_device: torch.device) -> None:
         """Setup any processes or distributed connections.
 
         This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
         environment before setup is complete.
         """
-        self.training_type_plugin.setup_environment()
 
     def setup(self, trainer: "pl.Trainer") -> None:
         """Setup plugins for the trainer fit and creates optimizers.
 
         Args:
             trainer: the trainer instance
         """
-        self.training_type_plugin.setup(trainer)
-
-    @property
-    def model(self) -> Module:
-        """Returns the model.
-
-        This can also be a wrapped LightningModule. For retrieving the pure LightningModule use
-        :attr:`Accelerator.lightning_module`
-        """
-        return self.training_type_plugin.model
-
-    @model.setter
-    def model(self, new_model: Module) -> None:
-        self.training_type_plugin.model = new_model
-
-    @property
-    def lightning_module(self) -> "pl.LightningModule":
-        """Returns the pure LightningModule.
-
-        To get the potentially wrapped model use :attr:`Accelerator.model`
-        """
-        return self.training_type_plugin.lightning_module
-
-    @property
-    def root_device(self) -> torch.device:
-        """Returns the root device."""
-        return self.training_type_plugin.root_device
-
-    def teardown(self) -> None:
-        """This method is called to teardown the training process.
-
-        It is the right place to release memory and free other resources.
-        """
-        self.training_type_plugin.teardown()
 
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for a given device.

pytorch_lightning/accelerators/cpu.py

Lines changed: 3 additions & 8 deletions
@@ -15,26 +15,21 @@
 
 import torch
 
-import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class CPUAccelerator(Accelerator):
     """Accelerator for CPU devices."""
 
-    def setup(self, trainer: "pl.Trainer") -> None:
+    def setup_environment(self, root_device: torch.device) -> None:
         """
         Raises:
             MisconfigurationException:
                 If the selected device is not CPU.
         """
-        if "cpu" not in str(self.training_type_plugin.root_device):
-            raise MisconfigurationException(
-                f"Device should be CPU, got {self.training_type_plugin.root_device} instead."
-            )
-
-        return super().setup(trainer)
+        if "cpu" not in str(root_device):
+            raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")
 
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """CPU device stats aren't supported yet."""

pytorch_lightning/accelerators/gpu.py

Lines changed: 5 additions & 12 deletions
@@ -30,22 +30,19 @@
 class GPUAccelerator(Accelerator):
     """Accelerator for GPU devices."""
 
-    def setup_environment(self) -> None:
+    def setup_environment(self, root_device: torch.device) -> None:
         """
         Raises:
             MisconfigurationException:
                 If the selected device is not GPU.
         """
-        super().setup_environment()
-        if "cuda" not in str(self.training_type_plugin.root_device):
-            raise MisconfigurationException(
-                f"Device should be GPU, got {self.training_type_plugin.root_device} instead"
-            )
-        torch.cuda.set_device(self.training_type_plugin.root_device)
+        if "cuda" not in str(root_device):
+            raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
+        torch.cuda.set_device(root_device)
 
     def setup(self, trainer: "pl.Trainer") -> None:
+        # TODO refactor input from trainer to local_rank @four4fish
         self.set_nvidia_flags(trainer.local_rank)
-        super().setup(trainer)
         # clear cache before training
         torch.cuda.empty_cache()

@@ -74,10 +71,6 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
             return torch.cuda.memory_stats(device)
         return get_nvidia_gpu_stats(device)
 
-    def teardown(self) -> None:
-        super().teardown()
-        self.training_type_plugin._move_optimizer_state(torch.device("cpu"))
-
     @staticmethod
     def auto_device_count() -> int:
         """Get the devices when set to auto."""

pytorch_lightning/accelerators/tpu.py

Lines changed: 0 additions & 23 deletions
@@ -15,11 +15,7 @@
 
 import torch
 
-import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
-from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
-from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
@@ -29,25 +25,6 @@
 class TPUAccelerator(Accelerator):
     """Accelerator for TPU devices."""
 
-    def setup(self, trainer: "pl.Trainer") -> None:
-        """
-        Raises:
-            ValueError:
-                If the precision or training type plugin are unsupported.
-        """
-        if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
-            # this configuration should have been avoided in the accelerator connector
-            raise ValueError(
-                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
-                f" found: {self.training_type_plugin.precision_plugin}."
-            )
-        if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
-            raise ValueError(
-                "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin,"
-                f" found {self.training_type_plugin}."
-            )
-        return super().setup(trainer)
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for the given TPU device.

pytorch_lightning/lite/lite.py

Lines changed: 3 additions & 3 deletions
@@ -99,8 +99,8 @@ def __init__(
             amp_level=None,
             plugins=plugins,
         )
-        self._accelerator = self._accelerator_connector.accelerator
-        self._strategy = self._accelerator.training_type_plugin
+        self._strategy = self._accelerator_connector.training_type_plugin
+        self._accelerator = self._strategy.accelerator
         self._precision_plugin = self._strategy.precision_plugin
         self._models_setup: int = 0

@@ -398,7 +398,7 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
         return seed_everything(seed=seed, workers=workers)
 
     def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
-        self._accelerator.setup_environment()
+        self._strategy.setup_environment()
 
         # apply sharded context to prevent OOM
         run_method = partial(self._run_with_sharded_context, run_method)
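
The same ownership inversion applies to any code that previously navigated from the accelerator to the strategy; a minimal sketch of the new direction, reusing the classes from the docs example above:

    from pytorch_lightning.accelerators import GPUAccelerator
    from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin

    strategy = DDPPlugin(
        accelerator=GPUAccelerator(),
        precision_plugin=NativeMixedPrecisionPlugin(precision=16, device="cuda"),
    )

    # Before: the strategy was reached through the accelerator (accelerator.training_type_plugin).
    # After:  the strategy exposes the accelerator and the precision plugin it now owns.
    accelerator = strategy.accelerator
    precision_plugin = strategy.precision_plugin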

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 3 additions & 0 deletions
@@ -84,6 +84,7 @@ class DDPPlugin(ParallelPlugin):
 
     def __init__(
         self,
+        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
@@ -95,6 +96,7 @@ def __init__(
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         super().__init__(
+            accelerator=accelerator,
             parallel_devices=parallel_devices,
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
@@ -147,6 +149,7 @@ def setup_environment(self) -> None:
             self._call_children_scripts()
 
         self.setup_distributed()
+        super().setup_environment()
 
     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,7 @@ class DDPSpawnPlugin(ParallelPlugin):
 
     def __init__(
         self,
+        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
@@ -72,6 +73,7 @@ def __init__(
         **kwargs: Any,
     ):
         super().__init__(
+            accelerator=accelerator,
             parallel_devices=parallel_devices,
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
