
Commit f811284

rohitgr7, tchaton, and ethanwharris authored
Update PT to PL conversion doc (#11397)
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
1 parent 9b0942d commit f811284

File tree: 1 file changed (+149 −32)

docs/source/starter/converting.rst

Lines changed: 149 additions & 32 deletions
@@ -6,17 +6,20 @@
 
 .. _converting:
 
-**************************************
+
+######################################
 How to organize PyTorch into Lightning
-**************************************
+######################################
 
-To enable your code to work with Lightning, here's how to organize PyTorch into Lightning
+To enable your code to work with Lightning, here's how to organize PyTorch into Lightning:
 
 --------
 
-1. Move your computational code
-===============================
-Move the model architecture and forward pass to your :doc:`lightning module <../common/lightning_module>`.
+*******************************
+1. Move your Computational Code
+*******************************
+
+Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.lightning.LightningModule`.
 
 .. testcode::
 
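The body of the ``testcode`` block above falls outside this hunk's context. As a rough sketch only (the layer names and sizes below are illustrative placeholders, not taken from the file), step 1 typically yields something like:

.. code-block:: python

    from torch import nn
    from torch.nn import functional as F
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            # the model architecture lives in __init__
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 10)

        def forward(self, x):
            # the forward pass defines the prediction/inference behaviour
            x = x.view(x.size(0), -1)
            x = F.relu(self.layer_1(x))
            return self.layer_2(x)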

@@ -35,23 +38,32 @@ Move the model architecture and forward pass to your :doc:`lightning module <../
 
 --------
 
-2. Move the optimizer(s) and schedulers
-=======================================
-Move your optimizers to the :func:`~pytorch_lightning.core.LightningModule.configure_optimizers` hook.
+********************************************
+2. Move the Optimizer(s) and LR Scheduler(s)
+********************************************
+
+Move your optimizers to the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` hook.
 
 .. testcode::
 
     class LitModel(LightningModule):
         def configure_optimizers(self):
             optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
-            return optimizer
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
 
 --------
 
-3. Find the train loop "meat"
-=============================
-Lightning automates most of the training for you, the epoch and batch iterations, all you need to keep is the training step logic.
-This should go into the :func:`~pytorch_lightning.core.LightningModule.training_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case):
+*******************************
+3. Configure the Training Logic
+*******************************
+
+Lightning automates the training loop for you and manages all of the associated components, such as epoch and batch tracking,
+optimizers and schedulers, and metric reduction. As a user, you just need to define how your model behaves with a batch of
+training data by overriding the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method, which takes
+the current ``batch`` and the ``batch_idx`` as arguments. Optionally, it can also take ``optimizer_idx`` if your
+LightningModule defines multiple optimizers within its
+:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` hook.
 
 .. testcode::
 
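The body of this ``testcode`` block also lies outside the hunk. Judging from the validation and test steps shown further down, a minimal ``training_step`` sketch (illustrative, reusing the imports from the sketch above, not the literal block from the file) looks like:

.. code-block:: python

    class LitModel(LightningModule):
        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            # the value returned here is what Lightning backpropagates
            return loss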

@@ -64,10 +76,17 @@ This should go into the :func:`~pytorch_lightning.core.LightningModule.training_
 
 --------
 
-4. Find the val loop "meat"
-===========================
+*********************************
+4. Configure the Validation Logic
+*********************************
+
+Lightning also automates the validation loop for you and manages the associated components, such as epoch and batch tracking and
+metric reduction. As a user, you just need to define how your model behaves with a batch of validation data by overriding the
+:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` method, which takes the current ``batch`` and the
+``batch_idx`` as arguments. Optionally, it can also take ``dataloader_idx`` if you configure multiple dataloaders.
+
 To add an (optional) validation loop add logic to the
-:func:`~pytorch_lightning.core.LightningModule.validation_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case).
+:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case).
 
 .. testcode::
 
@@ -76,38 +95,136 @@ To add an (optional) validation loop add logic to the
             x, y = batch
             y_hat = self(x)
             val_loss = F.cross_entropy(y_hat, y)
-            return val_loss
+            self.log("val_loss", val_loss)
+
+Additionally, you can run only the validation loop using the :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` method.
 
-.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation
+.. code-block:: python
+
+    model = LitModel()
+    trainer.validate(model)
+
+.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.
+
+.. tip:: ``trainer.validate()`` loads the best checkpoint automatically by default if checkpointing was enabled during fitting.
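The snippet above assumes a ``Trainer`` instance already exists. A fuller, hedged sketch of running only the validation loop (the bare ``Trainer()`` and ``val_loader`` below are placeholders, and the ``dataloaders`` keyword assumes a recent Lightning release):

.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitModel()
    trainer = Trainer()
    # val_loader: any torch.utils.data.DataLoader (placeholder)
    trainer.validate(model, dataloaders=val_loader)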
 
 --------
 
-5. Find the test loop "meat"
-============================
-To add an (optional) test loop add logic to the
-:func:`~pytorch_lightning.core.LightningModule.test_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case).
+**************************
+5. Configure Testing Logic
+**************************
+
+Lightning automates the testing loop for you and manages all the associated components, such as epoch and batch tracking and metric reduction.
+As a user, you just need to define how your model behaves with a batch of testing data by overriding the
+:meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` method, which takes the current ``batch`` and the ``batch_idx``
+as arguments. Optionally, it can also take ``dataloader_idx`` if you configure multiple dataloaders.
 
 .. testcode::
 
     class LitModel(LightningModule):
         def test_step(self, batch, batch_idx):
             x, y = batch
             y_hat = self(x)
-            loss = F.cross_entropy(y_hat, y)
-            return loss
+            test_loss = F.cross_entropy(y_hat, y)
+            self.log("test_loss", test_loss)
+
+The test loop isn't used within :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`; therefore, you need to call :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` explicitly.
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer.test(model)
+
+.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.
+
+.. tip:: ``trainer.test()`` loads the best checkpoint automatically by default if checkpointing is enabled.
+
+--------
+
+*****************************
+6. Configure Prediction Logic
+*****************************
+
+Lightning automates the prediction loop for you and manages the associated components, such as epoch and batch tracking.
+As a user, you just need to define how your model behaves with a batch of data by overriding the
+:meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method, which takes the current ``batch`` and the ``batch_idx``
+as arguments. Optionally, it can also take ``dataloader_idx`` if you configure multiple dataloaders.
+If you don't override the ``predict_step`` hook, Lightning calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` on the batch by default.
+
+.. testcode::
+
+    class LitModel(LightningModule):
+        def predict_step(self, batch, batch_idx):
+            x, y = batch
+            pred = self(x)
+            return pred
+
+The predict loop will not be used until you call :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer.predict(model)
 
 .. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.
 
-The test loop will not be used until you call.
+.. tip:: ``trainer.predict()`` loads the best checkpoint automatically by default if checkpointing is enabled.
+
+--------
+
+******************************************
+7. Remove any .cuda() or .to(device) Calls
+******************************************
+
+Your :doc:`LightningModule <../common/lightning_module>` can automatically run on any hardware!
+
+If you have any explicit calls to ``.cuda()`` or ``.to(device)``, you can remove them, since Lightning makes sure that the data coming from :class:`~torch.utils.data.DataLoader`
+and all the :class:`~torch.nn.Module` instances initialized inside ``LightningModule.__init__`` are moved to the respective devices automatically.
+
+.. testcode::
+
+    class LitModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("running_mean", torch.zeros(num_features))
 
-.. code-block::
+If you still need to access the current device, you can use ``self.device`` anywhere in your ``LightningModule`` except in the ``__init__`` method.
+If you initialize a :class:`~torch.Tensor` within the ``LightningModule.__init__`` method and want it to be moved to the device automatically,
+you should use :meth:`~torch.nn.Module.register_buffer` to register it as a buffer.
 
-    trainer.test()
+.. testcode::
 
-.. tip:: ``.test()`` loads the best checkpoint automatically
+    class LitModel(LightningModule):
+        def training_step(self, batch, batch_idx):
+            z = torch.randn(4, 5, device=self.device)
+            ...
 
 --------
 
-6. Remove any .cuda() or to.device() calls
-==========================================
-Your :doc:`lightning module <../common/lightning_module>` can automatically run on any hardware!
+********************
+8. Use Your Own Data
+********************
+
+To use your DataLoaders, you can override the respective dataloader hooks in the :class:`~pytorch_lightning.core.lightning.LightningModule`:
+
+.. testcode::
+
+    class LitModel(LightningModule):
+        def train_dataloader(self):
+            return DataLoader(...)
+
+        def val_dataloader(self):
+            return DataLoader(...)
+
+        def test_dataloader(self):
+            return DataLoader(...)
+
+        def predict_dataloader(self):
+            return DataLoader(...)
+
+Alternatively, you can pass your dataloaders in one of the following ways:
+
+* Pass in the dataloaders explicitly to the ``trainer.fit/.validate/.test/.predict`` calls (see the sketch below).
+* Use a :ref:`LightningDataModule <datamodules>`.
+
+Check out the :ref:`data` docs to understand data management within Lightning.
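For the first of those options, a hedged sketch of passing dataloaders directly to the ``Trainer`` calls (the datasets, batch size, and bare ``Trainer()`` below are placeholders; the keyword names assume Lightning 1.5+):

.. code-block:: python

    from torch.utils.data import DataLoader
    from pytorch_lightning import Trainer

    model = LitModel()
    trainer = Trainer()

    # train_dataset / val_dataset / test_dataset are placeholder torch Datasets
    trainer.fit(
        model,
        train_dataloaders=DataLoader(train_dataset, batch_size=32),
        val_dataloaders=DataLoader(val_dataset, batch_size=32),
    )
    trainer.test(model, dataloaders=DataLoader(test_dataset, batch_size=32))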
