Merged

Commits (42)
9d0a88f
remove training_step() from accelerator
four4fish Dec 4, 2021
ac313dc
remove test, val, predict step
four4fish Dec 4, 2021
c22ce58
move
awaelchli Dec 6, 2021
8ae530c
wip
awaelchli Dec 7, 2021
da00425
accelerator references
awaelchli Dec 8, 2021
7db6742
cpu training
awaelchli Dec 8, 2021
8c7fc95
rename occurrences in tests
awaelchli Dec 8, 2021
4afbf5c
update tests
awaelchli Dec 8, 2021
8fdce97
pull from adrian's commit
four4fish Dec 10, 2021
1c7bf4d
fix changelog merge pro
four4fish Dec 10, 2021
59920f7
fix accelerator_connector and other updates
four4fish Dec 10, 2021
7637a7c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2021
d378475
fix doc build and some mypy
four4fish Dec 10, 2021
2810731
fix lite
four4fish Dec 10, 2021
0f9d245
fix gpu setup environment
four4fish Dec 10, 2021
cc2648a
support customized ttp and accelerator
four4fish Dec 10, 2021
34b9544
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2021
46283e2
fix tpu error check
four4fish Dec 10, 2021
e6dfafe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2021
347a4f1
fix precision_plugin initialization to recognize customized plugin
four4fish Dec 11, 2021
c0120d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2021
f08163b
Update bug_report_model.py
four4fish Dec 14, 2021
34c13ff
Update accelerator_connector.py
four4fish Dec 14, 2021
6bdb464
update changelog
awaelchli Dec 15, 2021
c039c68
allow shorthand typing references to pl.Accelerator
awaelchli Dec 15, 2021
0976c50
rename helper method and add docstring
awaelchli Dec 15, 2021
7b1738c
fix typing
awaelchli Dec 15, 2021
2f18893
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
awaelchli Dec 15, 2021
bf97a58
Update tests/accelerators/test_accelerator_connector.py
awaelchli Dec 15, 2021
e0f4a77
Update tests/accelerators/test_cpu.py
awaelchli Dec 15, 2021
5488519
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2021
b69537e
fix pre commit complaint
awaelchli Dec 15, 2021
94fe8f8
update typing to long ugly path
awaelchli Dec 15, 2021
19bcf3f
spacing in flow diagram
awaelchli Dec 15, 2021
5862afe
remove todo comments
four4fish Dec 15, 2021
2dfc443
docformatter
awaelchli Dec 15, 2021
92cc262
Update pytorch_lightning/plugins/training_type/training_type_plugin.py
awaelchli Dec 15, 2021
ff3e2dc
revert test changes
four4fish Dec 15, 2021
9f1eade
improve custom plugin examples
four4fish Dec 15, 2021
a74f4c1
remove redundant call to ttp attribute
awaelchli Dec 16, 2021
292c640
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2021
448b524
Apply suggestions from code review
four4fish Dec 16, 2021
Files changed
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -114,6 +114,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))


- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022))


### Deprecated

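In code, the move described in the changelog entry above amounts to the following construction pattern, mirroring the documentation updates in this PR. This is a sketch only; it assumes a CUDA-capable machine and the imports shown in docs/source/extensions/accelerators.rst:

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin

# Previously the Accelerator was built with both plugins and passed as
# Trainer(accelerator=...). After this PR, the TrainingTypePlugin owns the
# accelerator (and the precision plugin) and is passed as Trainer(strategy=...).
accelerator = GPUAccelerator()
precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
training_type_plugin = DDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=training_type_plugin)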
9 changes: 4 additions & 5 deletions docs/source/extensions/accelerators.rst
@@ -25,11 +25,10 @@ One to handle differences from the training routine and one to handle different precisions.
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin

accelerator = GPUAccelerator(
precision_plugin=NativeMixedPrecisionPlugin(precision=16, device="cuda"),
training_type_plugin=DDPPlugin(),
)
trainer = Trainer(accelerator=accelerator)
accelerator = GPUAccelerator()
precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
training_type_plugin = DDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=training_type_plugin)


We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new
9 changes: 4 additions & 5 deletions docs/source/extensions/plugins.rst
@@ -80,11 +80,10 @@ can then be passed into the Trainer directly or via a (custom) accelerator:
trainer = Trainer(strategy=CustomDDPPlugin(), plugins=[CustomPrecisionPlugin()])

# fully custom accelerator and plugins
accelerator = MyAccelerator(
precision_plugin=CustomPrecisionPlugin(),
training_type_plugin=CustomDDPPlugin(),
)
trainer = Trainer(accelerator=accelerator)
accelerator = MyAccelerator()
precision_plugin = CustomPrecisionPlugin()
training_type_plugin = CustomDDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=training_type_plugin)


The full list of built-in plugins is listed below.
62 changes: 2 additions & 60 deletions pytorch_lightning/accelerators/accelerator.py
@@ -12,14 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Union

import torch
from torch.nn import Module

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin


class Accelerator:
@@ -31,76 +28,21 @@ class Accelerator:
- GPU
- TPU
- IPU
Each Accelerator gets two plugins upon initialization:
One to handle differences from the training routine and one to handle different precisions.
"""

def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None:
"""
Args:
precision_plugin: the plugin to handle precision-specific parts
.. deprecated::
The ``precision_plugin`` parameter has been deprecated and will be removed soon.
Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead.
training_type_plugin: the plugin to handle different training routines
"""

self.training_type_plugin = training_type_plugin

if precision_plugin is not None:
self.training_type_plugin._precision_plugin = precision_plugin

def setup_environment(self) -> None:
def setup_environment(self, root_device: torch.device) -> None:
"""Setup any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
environment before setup is complete.
"""
self.training_type_plugin.setup_environment()

def setup(self, trainer: "pl.Trainer") -> None:
"""Setup plugins for the trainer fit and creates optimizers.
Args:
trainer: the trainer instance
"""
self.training_type_plugin.setup(trainer)

@property
def model(self) -> Module:
"""Returns the model.
This can also be a wrapped LightningModule. For retrieving the pure LightningModule use
:attr:`Accelerator.lightning_module`
"""
return self.training_type_plugin.model

@model.setter
def model(self, new_model: Module) -> None:
self.training_type_plugin.model = new_model

@property
def lightning_module(self) -> "pl.LightningModule":
"""Returns the pure LightningModule.
To get the potentially wrapped model use :attr:`Accelerator.model`
"""
return self.training_type_plugin.lightning_module

@property
def root_device(self) -> torch.device:
"""Returns the root device."""
return self.training_type_plugin.root_device

def teardown(self) -> None:
"""This method is called to teardown the training process.
It is the right place to release memory and free other resources.
"""
self.training_type_plugin.teardown()

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for a given device.
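For custom accelerator authors, a minimal sketch of the slimmed-down interface above. The class name is hypothetical, and the real base class may require additional overrides beyond the hooks visible in these hunks (for example auto_device_count, which appears later in the gpu.py diff):

from typing import Any, Dict, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator


class CPUOnlyAccelerator(Accelerator):
    """Hypothetical accelerator written against the hooks visible in this diff."""

    def setup_environment(self, root_device: torch.device) -> None:
        # The root device now arrives as an argument from the owning
        # TrainingTypePlugin instead of being read from
        # self.training_type_plugin.root_device, which no longer exists.
        if root_device.type != "cpu":
            raise RuntimeError(f"Expected a CPU device, got {root_device}")

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        # No device stats are collected in this sketch.
        return {}

    @staticmethod
    def auto_device_count() -> int:
        return 1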
11 changes: 3 additions & 8 deletions pytorch_lightning/accelerators/cpu.py
@@ -15,26 +15,21 @@

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class CPUAccelerator(Accelerator):
"""Accelerator for CPU devices."""

def setup(self, trainer: "pl.Trainer") -> None:
def setup_environment(self, root_device: torch.device) -> None:
"""
Raises:
MisconfigurationException:
If the selected device is not CPU.
"""
if "cpu" not in str(self.training_type_plugin.root_device):
raise MisconfigurationException(
f"Device should be CPU, got {self.training_type_plugin.root_device} instead."
)

return super().setup(trainer)
if "cpu" not in str(root_device):
raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""CPU device stats aren't supported yet."""
17 changes: 5 additions & 12 deletions pytorch_lightning/accelerators/gpu.py
@@ -30,22 +30,19 @@
class GPUAccelerator(Accelerator):
"""Accelerator for GPU devices."""

def setup_environment(self) -> None:
def setup_environment(self, root_device: torch.device) -> None:
"""
Raises:
MisconfigurationException:
If the selected device is not GPU.
"""
super().setup_environment()
if "cuda" not in str(self.training_type_plugin.root_device):
raise MisconfigurationException(
f"Device should be GPU, got {self.training_type_plugin.root_device} instead"
)
torch.cuda.set_device(self.training_type_plugin.root_device)
if "cuda" not in str(root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
torch.cuda.set_device(root_device)

def setup(self, trainer: "pl.Trainer") -> None:
# TODO refactor input from trainer to local_rank @four4fish
self.set_nvidia_flags(trainer.local_rank)
super().setup(trainer)
# clear cache before training
torch.cuda.empty_cache()

@@ -74,10 +71,6 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
return torch.cuda.memory_stats(device)
return get_nvidia_gpu_stats(device)

def teardown(self) -> None:
super().teardown()
self.training_type_plugin._move_optimizer_state(torch.device("cpu"))

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
23 changes: 0 additions & 23 deletions pytorch_lightning/accelerators/tpu.py
@@ -15,11 +15,7 @@

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE

if _XLA_AVAILABLE:
@@ -29,25 +25,6 @@
class TPUAccelerator(Accelerator):
"""Accelerator for TPU devices."""

def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
ValueError:
If the precision or training type plugin are unsupported.
"""
if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
# this configuration should have been avoided in the accelerator connector
raise ValueError(
f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
f" found: {self.training_type_plugin.precision_plugin}."
)
if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
raise ValueError(
"The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin,"
f" found {self.training_type_plugin}."
)
return super().setup(trainer)

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for the given TPU device.

6 changes: 3 additions & 3 deletions pytorch_lightning/lite/lite.py
@@ -99,8 +99,8 @@ def __init__(
amp_level=None,
plugins=plugins,
)
self._accelerator = self._accelerator_connector.accelerator
self._strategy = self._accelerator.training_type_plugin
self._strategy = self._accelerator_connector.training_type_plugin
self._accelerator = self._strategy.accelerator
self._precision_plugin = self._strategy.precision_plugin
self._models_setup: int = 0

@@ -398,7 +398,7 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
return seed_everything(seed=seed, workers=workers)

def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._accelerator.setup_environment()
self._strategy.setup_environment()

# apply sharded context to prevent OOM
run_method = partial(self._run_with_sharded_context, run_method)
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -84,6 +84,7 @@ class DDPPlugin(ParallelPlugin):

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
@@ -95,6 +96,7 @@ def __init__(
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
@@ -147,6 +149,7 @@ def setup_environment(self) -> None:
self._call_children_scripts()

self.setup_distributed()
super().setup_environment()

def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -62,6 +62,7 @@ class DDPSpawnPlugin(ParallelPlugin):

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
@@ -72,6 +73,7 @@ def __init__(
**kwargs: Any,
):
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -88,6 +88,7 @@ class DeepSpeedPlugin(DDPPlugin):

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
zero_optimization: bool = True,
stage: int = 2,
remote_device: str = "cpu",
@@ -273,6 +274,7 @@ def __init__(
)

super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
precision_plugin=precision_plugin,
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -35,11 +35,13 @@ class DataParallelPlugin(ParallelPlugin):

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=None,
checkpoint_io=checkpoint_io,
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -37,6 +37,7 @@ class DDPFullyShardedPlugin(DDPPlugin):

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
cpu_offload: bool = False,
flatten_parameters: bool = True,
reshard_after_forward: bool = True,
@@ -98,6 +99,7 @@ def __init__(
"""

super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -41,11 +41,13 @@ class HorovodPlugin(ParallelPlugin):

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=None,
checkpoint_io=checkpoint_io,
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -62,6 +62,7 @@ class IPUPlugin(ParallelPlugin):

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
device_iterations: int = 1,
autoreport: bool = False,
autoreport_dir: Optional[str] = None,
@@ -86,6 +87,7 @@ def __init__(
created options for validation/testing and predicting.
"""
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/parallel.py
@@ -34,12 +34,13 @@ class ParallelPlugin(TrainingTypePlugin, ABC):

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
self.parallel_devices = parallel_devices
self.cluster_environment = cluster_environment

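Because every built-in plugin constructor above gains the same optional accelerator argument and forwards it to its parent (DDPPlugin/ParallelPlugin -> TrainingTypePlugin), a third-party plugin only needs to do the same. A hedged sketch; the subclass and its extra flag are hypothetical, and the remaining DDPPlugin keyword arguments pass through **kwargs:

from typing import Any, Optional

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin


class MyClusterDDPPlugin(DDPPlugin):
    """Hypothetical custom plugin following the new constructor pattern."""

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        my_flag: bool = False,  # hypothetical extra option
        **kwargs: Any,
    ) -> None:
        # Forward `accelerator` (and everything else) up the chain so that the
        # TrainingTypePlugin ends up owning the accelerator instance.
        super().__init__(accelerator=accelerator, **kwargs)
        self.my_flag = my_flag

It can then be passed to the Trainer via the strategy argument, exactly as in the docs example above, e.g. Trainer(strategy=MyClusterDDPPlugin(accelerator=GPUAccelerator())).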