Commit f6361f7

williamFalcon and SeanNaren authored and committed
Accelerator docs (#4583)
* accelerator docs

* accelerator docs

(cherry picked from commit ee35907)
1 parent 65cfa17 · commit f6361f7

14 files changed: +287 −4 lines

docs/source/accelerators.rst

Lines changed: 182 additions & 0 deletions

@@ -0,0 +1,182 @@
############
Accelerators
############
Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, etc.). Accelerators
also manage distributed accelerators (like DP, DDP, and HPC clusters).

Accelerators can also be configured to run on arbitrary clusters using Plugins, or to link up to arbitrary
computational strategies like 16-bit precision via AMP and Apex.
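
For instance, an accelerator instance can be passed directly to the Trainer. A minimal sketch, mirroring
the ``Example`` blocks in the accelerator docstrings added by this commit (normally the Trainer constructs
the accelerator for you):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator

    # equivalent to the default CPU setup
    trainer = Trainer(accelerator=CPUAccelerator())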

----------

******************************
Implement a custom accelerator
******************************
To link up arbitrary hardware, implement your own Accelerator subclass.

.. code-block:: python

    from typing import Any, Optional, Union

    import torch
    from torch.distributed import ReduceOp

    from pytorch_lightning.accelerators.accelerator import Accelerator

    class MyAccelerator(Accelerator):

        def __init__(self, trainer, cluster_environment=None):
            super().__init__(trainer, cluster_environment)
            self.nickname = 'my_accelerator'

        def setup(self):
            # find local rank, etc, custom things to implement
            ...

        def train(self):
            # implement what happens during training
            ...

        def training_step(self):
            # implement how to do a training_step on this accelerator
            ...

        def validation_step(self):
            # implement how to do a validation_step on this accelerator
            ...

        def test_step(self):
            # implement how to do a test_step on this accelerator
            ...

        def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
            # implement how to do a backward pass with this accelerator
            ...

        def barrier(self, name: Optional[str] = None):
            # implement this accelerator's barrier
            ...

        def broadcast(self, obj, src=0):
            # implement this accelerator's broadcast function
            ...

        def sync_tensor(self,
                        tensor: Union[torch.Tensor],
                        group: Optional[Any] = None,
                        reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
            # implement how to sync tensors when reducing metrics across accelerators
            ...
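
As a point of reference, a single-process accelerator has nothing to reduce across processes, so a minimal
``sync_tensor`` can simply return its input. A hedged sketch (an illustration, not part of this commit):

.. code-block:: python

    def sync_tensor(self,
                    tensor: Union[torch.Tensor],
                    group: Optional[Any] = None,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        # single process: nothing to reduce, return the tensor unchanged
        return tensor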

********
Examples
********
The following examples illustrate customizing accelerators.

Example 1: Arbitrary HPC cluster
================================
To link any accelerator with an arbitrary cluster (SLURM, Condor, etc.), pass in a cluster Plugin, which
will be handed to the accelerator.

First, implement your own ClusterEnvironment. Here is the torch elastic implementation.
.. code-block:: python

    import os

    from pytorch_lightning import _logger as log
    from pytorch_lightning.utilities import rank_zero_warn
    from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment

    class TorchElasticEnvironment(ClusterEnvironment):

        def __init__(self):
            super().__init__()

        def master_address(self):
            if "MASTER_ADDR" not in os.environ:
                rank_zero_warn(
                    "MASTER_ADDR environment variable is not defined. Set as localhost"
                )
                os.environ["MASTER_ADDR"] = "127.0.0.1"
            log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
            master_address = os.environ.get('MASTER_ADDR')
            return master_address

        def master_port(self):
            if "MASTER_PORT" not in os.environ:
                rank_zero_warn(
                    "MASTER_PORT environment variable is not defined. Set as 12910"
                )
                os.environ["MASTER_PORT"] = "12910"
            log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

            port = os.environ.get('MASTER_PORT')
            return port

        def world_size(self):
            return os.environ.get('WORLD_SIZE')

        def local_rank(self):
            return int(os.environ['LOCAL_RANK'])

Now, pass it into the Trainer, which will use Torch Elastic across your accelerator of choice.

.. code-block:: python

    cluster = TorchElasticEnvironment()
    accelerator = MyAccelerator()
    trainer = Trainer(plugins=[cluster], accelerator=accelerator)

In this example, MyAccelerator can define arbitrary hardware (like IPUs or TPUs) and link it to an
arbitrary compute cluster.
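
The same pattern applies to any scheduler. A hedged sketch of a SLURM-style environment, reading SLURM's
standard environment variables (an illustration with assumed names and defaults, not the library's
built-in SLURM support):

.. code-block:: python

    import os

    from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment

    class MySLURMEnvironment(ClusterEnvironment):
        # hypothetical example for illustration only

        def master_address(self):
            # assume the user (or a launch script) exported MASTER_ADDR
            return os.environ.get('MASTER_ADDR', '127.0.0.1')

        def master_port(self):
            return os.environ.get('MASTER_PORT', '12910')

        def world_size(self):
            # SLURM exposes the total task count as SLURM_NTASKS
            return os.environ.get('SLURM_NTASKS')

        def local_rank(self):
            # SLURM exposes the per-node task id as SLURM_LOCALID
            return int(os.environ['SLURM_LOCALID'])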

------------

**********************
Available Accelerators
**********************

CPU Accelerator
===============

.. autoclass:: pytorch_lightning.accelerators.cpu_accelerator.CPUAccelerator
    :noindex:

DDP Accelerator
===============

.. autoclass:: pytorch_lightning.accelerators.ddp_accelerator.DDPAccelerator
    :noindex:

DDP2 Accelerator
================

.. autoclass:: pytorch_lightning.accelerators.ddp2_accelerator.DDP2Accelerator
    :noindex:

DDP CPU HPC Accelerator
=======================

.. autoclass:: pytorch_lightning.accelerators.ddp_cpu_hpc_accelerator.DDPCPUHPCAccelerator
    :noindex:

DDP CPU Spawn Accelerator
=========================

.. autoclass:: pytorch_lightning.accelerators.ddp_cpu_spawn_accelerator.DDPCPUSpawnAccelerator
    :noindex:

DDP HPC Accelerator
===================

.. autoclass:: pytorch_lightning.accelerators.ddp_hpc_accelerator.DDPHPCAccelerator
    :noindex:

DDP Spawn Accelerator
=====================

.. autoclass:: pytorch_lightning.accelerators.ddp_spawn_accelerator.DDPSpawnAccelerator
    :noindex:

GPU Accelerator
===============

.. autoclass:: pytorch_lightning.accelerators.gpu_accelerator.GPUAccelerator
    :noindex:

Horovod Accelerator
===================

.. autoclass:: pytorch_lightning.accelerators.horovod_accelerator.HorovodAccelerator
    :noindex:

TPU Accelerator
===============

.. autoclass:: pytorch_lightning.accelerators.tpu_accelerator.TPUAccelerator
    :noindex:

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ PyTorch Lightning Documentation
       :name: docs
       :caption: Optional extensions

+   accelerators
    callbacks
    datamodules
    logging

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 0 deletions

@@ -221,11 +221,13 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         """
         Function to reduce a tensor from several distributed processes to one aggregated tensor.
+
         Args:
             tensor: the tensor to sync and reduce
             group: the process group to gather results from. Defaults to all processes (world)
             reduce_op: the reduction operation. Defaults to sum.
                 Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
+
         Return:
             reduced value
         """

pytorch_lightning/accelerators/cpu_accelerator.py

Lines changed: 9 additions & 0 deletions

@@ -21,6 +21,15 @@
 class CPUAccelerator(Accelerator):

     def __init__(self, trainer, cluster_environment=None):
+        """
+        Runs training on CPU
+
+        Example::
+
+            # default
+            trainer = Trainer(accelerator=CPUAccelerator())
+
+        """
         super().__init__(trainer, cluster_environment)
         self.nickname = None

pytorch_lightning/accelerators/ddp2_accelerator.py

Lines changed: 9 additions & 0 deletions

@@ -38,6 +38,15 @@
 class DDP2Accelerator(Accelerator):

     def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+        """
+        Runs training using DDP2 strategy on a cluster
+
+        Example::
+
+            # default
+            trainer = Trainer(accelerator=DDP2Accelerator())
+
+        """
         super().__init__(trainer, cluster_environment, ddp_plugin)
         self.task_idx = None
         self.dist = LightningDistributed()

pytorch_lightning/accelerators/ddp_accelerator.py

Lines changed: 12 additions & 0 deletions

@@ -47,6 +47,15 @@
 class DDPAccelerator(Accelerator):

     def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+        """
+        Runs training using DDP strategy on a single machine (manually, not via cluster start)
+
+        Example::
+
+            # default
+            trainer = Trainer(accelerator=DDPAccelerator())
+
+        """
         super().__init__(trainer, cluster_environment, ddp_plugin)
         self.task_idx = None
         self._has_spawned_children = False

@@ -304,4 +313,7 @@ def sync_tensor(self,
                     tensor: Union[torch.Tensor],
                     group: Optional[Any] = None,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        """
+
+        """
         return sync_ddp_if_available(tensor, group, reduce_op)

pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py

Lines changed: 9 additions & 0 deletions

@@ -26,6 +26,15 @@
 class DDPCPUHPCAccelerator(DDPHPCAccelerator):

     def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+        """
+        Runs training using DDP (with CPUs) strategy on a cluster
+
+        Example::
+
+            # default
+            trainer = Trainer(accelerator=DDPCPUHPCAccelerator())
+
+        """
         super().__init__(trainer, cluster_environment, ddp_plugin)
         self.nickname = 'ddp_cpu'

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py

Lines changed: 9 additions & 0 deletions

@@ -40,6 +40,15 @@
 class DDPCPUSpawnAccelerator(Accelerator):

     def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
+        """
+        Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn
+
+        Example::
+
+            # default
+            trainer = Trainer(accelerator=DDPCPUSpawnAccelerator())
+
+        """
         super().__init__(trainer, cluster_environment, ddp_plugin)
         self.mp_queue = None
         self.nprocs = nprocs

pytorch_lightning/accelerators/ddp_hpc_accelerator.py

Lines changed: 9 additions & 0 deletions

@@ -39,6 +39,15 @@
 class DDPHPCAccelerator(Accelerator):

     def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+        """
+        Runs training using DDP on an HPC cluster
+
+        Example::
+
+            # default
+            trainer = Trainer(accelerator=DDPHPCAccelerator())
+
+        """
         super().__init__(trainer, cluster_environment, ddp_plugin)
         self.task_idx = None
         self._has_spawned_children = False

pytorch_lightning/accelerators/ddp_spawn_accelerator.py

Lines changed: 9 additions & 0 deletions

@@ -43,6 +43,15 @@
 class DDPSpawnAccelerator(Accelerator):

     def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
+        """
+        Runs training using DDP using mp.spawn via manual launch (not cluster launch)
+
+        Example::
+
+            # default
+            trainer = Trainer(accelerator=DDPSpawnAccelerator())
+
+        """
         super().__init__(trainer, cluster_environment, ddp_plugin)
         self.mp_queue = None
         self.nprocs = nprocs
