@@ -5,10 +5,14 @@ Lightning supports running on TPUs. At this moment, TPUs are only available
on Google Cloud (GCP). For more information on TPUs
`watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.

+ ---------------
+
Live demo
----------
Check out this `Google Colab <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_ to see how to train MNIST on TPUs.

+ ---------------
+
TPU Terminology
---------------
A TPU is a Tensor processing unit. Each TPU has 8 cores where each
@@ -19,30 +23,34 @@ A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores!
You can request a full pod from Google Cloud or a "slice" which gives you
some subset of those 2048 cores.

+ ---------------
+
How to access TPUs
-------------------
To access TPUs, there are two main ways.

1. Using Google Colab.
2. Using Google Cloud (GCP).

+ ---------------
+
Colab TPUs
-----------
Colab is like a Jupyter notebook with a free GPU or TPU
hosted on GCP.

To get a TPU on Colab, follow these steps:

- 1. Go to https://colab.research.google.com/.
+ 1. Go to https://colab.research.google.com/.

- 2. Click "new notebook" (bottom right of pop-up).
+ 2. Click "new notebook" (bottom right of pop-up).

- 3. Click runtime > change runtime settings. Select Python 3,
- and hardware accelerator "TPU". This will give you a TPU with 8 cores.
+ 3. Click runtime > change runtime settings. Select Python 3,
+ and hardware accelerator "TPU". This will give you a TPU with 8 cores.

- 4. Next, insert this code into the first cell and execute. This
- will install the xla library that interfaces between PyTorch and
- the TPU.
+ 4. Next, insert this code into the first cell and execute. This
+ will install the xla library that interfaces between PyTorch and
+ the TPU.

.. code-block:: python

@@ -86,16 +94,28 @@ the TPU.
    ! pip install "$TORCHVISION_WHEEL"
    ! sudo apt-get install libomp5
    update.join()
- 5. Once the above is done, install PyTorch Lightning (v 0.7.0+).
+
+ 5. Once the above is done, install PyTorch Lightning (v 0.7.0+).

.. code-block::

    ! pip install pytorch-lightning

6. Then set up your LightningModule as normal.

- 7. TPUs require a DistributedSampler. That means you should change your
- train_dataloader (and val, train) code as follows.
+ ---------------
+
+ DistributedSamplers
+ -------------------
+ Lightning automatically inserts the correct samplers - no need to do this yourself!
+
+ Usually, with TPUs (and DDP), you would need to define a DistributedSampler to move the right
+ chunk of data to the appropriate TPU. As mentioned, this is not needed in Lightning.
+
+ .. note:: Don't add DistributedSamplers. Lightning does this automatically.
+
+ If for some reason you still need to, this is how to construct the sampler
+ for TPU use:

.. code-block:: python

@@ -140,6 +160,15 @@ train_dataloader (and val, train) code as follows.

That's it! Your model will train on all 8 TPU cores.

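If you do build the sampler yourself, a minimal sketch could look like the following, assuming ``torch_xla``'s ``xm.xrt_world_size()`` and ``xm.get_ordinal()`` supply the replica count and rank (the dataset below is only a stand-in):

.. code-block:: python

    import torch
    import torch_xla.core.xla_model as xm
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # stand-in dataset; replace with your own
    dataset = TensorDataset(torch.randn(1024, 32), torch.randint(0, 10, (1024,)))

    # give each TPU core its own shard of the data
    sampler = DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),  # number of TPU cores in the job
        rank=xm.get_ordinal(),             # index of the current core
        shuffle=True,
    )

    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
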
+ ---------------
+
+ Distributed Backend with TPU
+ ----------------------------
+ The ``distributed_backend`` option used for GPUs does not apply to TPUs.
+ TPUs work in DDP mode by default (distributing over each core).
+
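As a hedged illustration, requesting TPU cores is all that is needed; the ``num_tpu_cores`` flag below is assumed from the 0.7.x Trainer API, and ``MyLightningModule`` is a placeholder for your own model:

.. code-block:: python

    import pytorch_lightning as pl

    # no distributed_backend argument: asking for TPU cores already implies
    # DDP-style distribution over the 8 cores
    trainer = pl.Trainer(num_tpu_cores=8)
    trainer.fit(MyLightningModule())  # MyLightningModule is a placeholder
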
+ ---------------
+
TPU Pod
--------
To train on more than 8 cores, your code actually doesn't change!
@@ -152,6 +181,8 @@ All you need to do is submit the following command:
    --conda-env=torch-xla-nightly
    -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data

+ ---------------
+
16-bit precision
-----------------
Lightning also supports training in 16-bit precision with TPUs.
@@ -168,6 +199,7 @@ set the 16-bit flag.

Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.

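To make "under the hood" concrete, here is a rough sketch of the mechanism, assuming ``torch_xla``'s ``XLA_USE_BF16`` environment variable, which has to be set before the XLA device is created:

.. code-block:: python

    import os

    # torch_xla reads this flag and stores float32 tensors as bfloat16 on the TPU;
    # it must be set before any XLA device is created
    os.environ["XLA_USE_BF16"] = "1"

    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # float32 tensors moved here are kept as bfloat16
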
+ ---------------

About XLA
----------