
Commit 1ab5c46

williamFalcon authored and tullie committed
* fixed docs
* fixed docs
* fixed docs
1 parent 1a30e4a commit 1ab5c46

File tree

3 files changed (+70 additions, -12 deletions)


docs/source/tpu.rst

Lines changed: 42 additions & 10 deletions
@@ -5,10 +5,14 @@ Lightning supports running on TPUs. At this moment, TPUs are only available
 on Google Cloud (GCP). For more information on TPUs
 `watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.

+---------------
+
 Live demo
 ----------
 Check out this `Google Colab <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_ to see how to train MNIST on TPUs.

+---------------
+
 TPU Terminology
 ---------------
 A TPU is a Tensor processing unit. Each TPU has 8 cores where each
@@ -19,30 +23,34 @@ A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores!
 You can request a full pod from Google cloud or a "slice" which gives you
 some subset of those 2048 cores.

+---------------
+
 How to access TPUs
 -------------------
 To access TPUs there are two main ways.

 1. Using google colab.
 2. Using Google Cloud (GCP).

+---------------
+
 Colab TPUs
 -----------
 Colab is like a jupyter notebook with a free GPU or TPU
 hosted on GCP.

 To get a TPU on colab, follow these steps:

-1. Go to https://colab.research.google.com/.
+1. Go to https://colab.research.google.com/.

-2. Click "new notebook" (bottom right of pop-up).
+2. Click "new notebook" (bottom right of pop-up).

-3. Click runtime > change runtime settings. Select Python 3,
-and hardware accelerator "TPU". This will give you a TPU with 8 cores.
+3. Click runtime > change runtime settings. Select Python 3,
+and hardware accelerator "TPU". This will give you a TPU with 8 cores.

-4. Next, insert this code into the first cell and execute. This
-will install the xla library that interfaces between PyTorch and
-the TPU.
+4. Next, insert this code into the first cell and execute. This
+will install the xla library that interfaces between PyTorch and
+the TPU.

 .. code-block:: python

@@ -86,16 +94,28 @@ the TPU.
     !pip install "$TORCHVISION_WHEEL"
     !sudo apt-get install libomp5
     update.join()
-5. Once the above is done, install PyTorch Lightning (v 0.7.0+).
+
+5. Once the above is done, install PyTorch Lightning (v 0.7.0+).

 .. code-block::

     ! pip install pytorch-lightning

 6. Then set up your LightningModule as normal.

-7. TPUs require a DistributedSampler. That means you should change your
-train_dataloader (and val, train) code as follows.
+---------------
+
+DistributedSamplers
+-------------------
+Lightning automatically inserts the correct samplers - no need to do this yourself!
+
+Usually, with TPUs (and DDP), you would need to define a DistributedSampler to move the right
+chunk of data to the appropriate TPU. As mentioned, this is not needed in Lightning.
+
+.. note:: Don't add DistributedSamplers. Lightning does this automatically.
+
+If for some reason you still need to, this is how to construct the sampler
+for TPU use:

 .. code-block:: python

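The hunk cuts off at the ``.. code-block:: python`` directive, so the sampler code itself is not shown here. For orientation, a hand-built sampler for TPU use might look roughly like the sketch below; this is not the verbatim code from the file, and it assumes torch_xla's ``xm.xrt_world_size()`` / ``xm.get_ordinal()`` helpers plus a placeholder MNIST dataset.

.. code-block:: python

    import torch_xla.core.xla_model as xm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from torchvision import transforms
    from torchvision.datasets import MNIST

    def train_dataloader(self):
        dataset = MNIST('data', train=True, download=True,
                        transform=transforms.ToTensor())

        # shard the data so each TPU core sees a distinct subset
        sampler = DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),  # total number of TPU cores
            rank=xm.get_ordinal(),             # index of the current core
            shuffle=True,
        )
        return DataLoader(dataset, sampler=sampler, batch_size=32)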
@@ -140,6 +160,15 @@ train_dataloader (and val, train) code as follows.

 That's it! Your model will train on all 8 TPU cores.

+---------------
+
+Distributed Backend with TPU
+----------------------------
+The ```distributed_backend``` option used for GPUs does not apply to TPUs.
+TPUs work in DDP mode by default (distributing over each core).
+
+---------------
+
 TPU Pod
 --------
 To train on more than 8 cores, your code actually doesn't change!
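To make the "Distributed Backend with TPU" note above concrete: on GPUs the backend is selected explicitly, while on TPUs the per-core DDP behaviour is implicit. A minimal sketch, assuming the 0.7-era ``num_tpu_cores`` Trainer argument:

.. code-block:: python

    from pytorch_lightning import Trainer

    # GPUs: the backend is chosen via distributed_backend
    trainer = Trainer(gpus=8, distributed_backend='ddp')

    # TPUs: no distributed_backend flag is needed; DDP-style
    # distribution over the 8 cores happens automatically
    trainer = Trainer(num_tpu_cores=8)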
@@ -152,6 +181,8 @@ All you need to do is submit the following command:
     --conda-env=torch-xla-nightly
     -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data

+---------------
+
 16 bit precision
 -----------------
 Lightning also supports training in 16-bit precision with TPUs.
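Setting the 16-bit flag for TPU training might look like the following sketch, assuming the ``precision`` and ``num_tpu_cores`` Trainer arguments from the 0.7.x API:

.. code-block:: python

    from pytorch_lightning import Trainer

    # bfloat16 is used under the hood when training on TPUs with precision=16
    trainer = Trainer(num_tpu_cores=8, precision=16)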
@@ -168,6 +199,7 @@ set the 16-bit flag.

 Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.

+---------------

 About XLA
 ----------

pytorch_lightning/core/__init__.py

Lines changed: 24 additions & 1 deletion
@@ -107,9 +107,10 @@ def configure_optimizers(self):
 -----------------------

 The general pattern is that each loop (training, validation, test loop)
-has 2 methods:
+has 3 methods:

 - ``` ___step ```
+- ``` ___step_end ```
 - ``` ___epoch_end```

 To show how lightning calls these, let's use the validation loop as an example
@@ -126,6 +127,28 @@ def configure_optimizers(self):
     # like calculate validation set accuracy or loss
     validation_epoch_end(val_outs)

+If we use dp or ddp2 mode, we can also define the ```XXX_step_end``` method to operate
+on all parts of the batch:
+
+.. code-block:: python
+
+    val_outs = []
+    for val_batch in val_data:
+        batches = split_batch(val_batch)
+        dp_outs = []
+        for sub_batch in batches:
+            dp_out = validation_step(sub_batch)
+            dp_outs.append(dp_out)
+
+        out = validation_step_end(dp_outs)
+        val_outs.append(out)
+
+    # do something with the outputs for all batches
+    # like calculate validation set accuracy or loss
+    validation_epoch_end(val_outs)
+
+.. note:: ```training_step_end``` is not available yet but coming in the next release.
+
 Add validation loop
 ^^^^^^^^^^^^^^^^^^^
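Inside a LightningModule, the dp/ddp2 pattern above might be wired up roughly as in this sketch; the model, loss and key names are placeholders rather than code from the diff:

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class CoolModel(pl.LightningModule):
        # assumes __init__ and forward are defined as usual

        def validation_step(self, batch, batch_idx):
            # runs on the slice of the batch that lives on this GPU (dp/ddp2)
            x, y = batch
            y_hat = self(x)
            return {'val_loss': F.cross_entropy(y_hat, y)}

        def validation_step_end(self, outputs):
            # combines the partial outputs of the full batch from all GPUs
            return {'val_loss': outputs['val_loss'].mean()}

        def validation_epoch_end(self, outputs):
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            return {'val_loss': avg_loss}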

pytorch_lightning/trainer/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -146,7 +146,8 @@ def forward(self, x):
 callbacks
 ^^^^^^^^^

-Add a list of user defined callbacks.
+Add a list of user defined callbacks. These callbacks DO NOT replace the explicit callbacks
+(loggers, EarlyStopping or ModelCheckpoint).

 .. note:: Only user defined callbacks (ie: Not EarlyStopping or ModelCheckpoint)

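A user-defined callback passed this way might look like the sketch below, assuming the ``Callback`` base class from ``pytorch_lightning.callbacks`` and the zero-argument hook style this file's own example uses (``def on_train_end(self)``):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback

    class PrintingCallback(Callback):

        def on_train_start(self):
            print('Training has started')

        def on_train_end(self):
            print('Training is done')

    # user-defined callbacks go in this list; loggers, EarlyStopping and
    # ModelCheckpoint keep their own dedicated Trainer arguments
    trainer = Trainer(callbacks=[PrintingCallback()])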
@@ -239,6 +240,8 @@ def on_train_end(self):
     # ddp2 = DistributedDataParallel + dp
     trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

+.. note:: This option does not apply to TPU. TPUs use ```ddp``` by default (over each core).
+
 early_stop_callback
 ^^^^^^^^^^^^^^^^^^^

