*******************************
3. Configure the Training Logic
*******************************

Lightning automates the training loop for you and manages all of the associated components, such as epoch and batch tracking,
optimizers and schedulers, and metric reduction. As a user, you just need to define how your model behaves with a batch of
training data: simply override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method, which takes
the current ``batch`` and the ``batch_idx`` as arguments. Optionally, it can take ``optimizer_idx`` if your LightningModule
defines multiple optimizers within its :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` hook.
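
For example, a minimal ``training_step`` can be sketched as follows (the model, loss function, and logged metric name are illustrative, mirroring the validation and testing examples below):

.. testcode::

    class LitModel(LightningModule):
        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            # log the metric and return the loss so Lightning can run the backward pass
            self.log("train_loss", loss)
            return loss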

--------

*********************************
4. Configure the Validation Logic
*********************************

Lightning also automates the validation loop for you and manages all of the associated components, such as epoch and batch tracking
and metrics reduction. As a user, you just need to define how your model behaves with a batch of validation data: simply override the
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` method, which takes the current ``batch`` and the
``batch_idx`` as arguments. Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.

To add an (optional) validation loop, add logic to the
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case).

.. testcode::

    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", val_loss)

Additionally, you can run only the validation loop using the :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` method.

.. code-block:: python

    model = LitModel()
    trainer.validate(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.

.. tip:: ``trainer.validate()`` loads the best checkpoint automatically by default if checkpointing was enabled during fitting.

--------

**************************
5. Configure Testing Logic
**************************

Lightning automates the testing loop for you and manages all the associated components, such as epoch and batch tracking
and metrics reduction. As a user, you just need to define how your model behaves with a batch of testing data: simply override the
:meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` method, which takes the current ``batch`` and the
``batch_idx`` as arguments. Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.

.. testcode::

    class LitModel(LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            test_loss = F.cross_entropy(y_hat, y)
            self.log("test_loss", test_loss)

The test loop isn't used within :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`; therefore, you need to call :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` explicitly.

.. code-block:: python

    model = LitModel()
    trainer.test(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

.. tip:: ``trainer.test()`` loads the best checkpoint automatically by default if checkpointing is enabled.

--------

*****************************
6. Configure Prediction Logic
*****************************

Lightning automates the prediction loop for you and manages all of the associated components, such as epoch and batch tracking.
As a user, you just need to define how your model behaves with a batch of data: simply override the
:meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method, which takes the current ``batch`` and the
``batch_idx`` as arguments. Optionally, it can take ``dataloader_idx`` if you configure multiple dataloaders.
If you don't override the ``predict_step`` hook, it calls the :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method on the batch by default.

.. testcode::

    class LitModel(LightningModule):
        def predict_step(self, batch, batch_idx):
            x, y = batch
            pred = self(x)
            return pred
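
When ``predict_step`` is not overridden, the entire batch is passed to :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` instead, so a roughly equivalent sketch (with a hypothetical ``self.layer`` submodule, and ``forward`` unpacking the batch itself) is:

.. testcode::

    class LitModel(LightningModule):
        def forward(self, batch):
            # the default predict_step calls self(batch) with the whole batch
            x, y = batch
            return self.layer(x)  # ``self.layer`` is a hypothetical submodule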

The predict loop will not be used until you call :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

.. code-block:: python

    model = LitModel()
    trainer.predict(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for prediction.

.. tip:: ``trainer.predict()`` loads the best checkpoint automatically by default if checkpointing is enabled.

--------

******************************************
7. Remove any .cuda() or .to(device) Calls
******************************************

Your :doc:`LightningModule <../common/lightning_module>` can automatically run on any hardware!

If you have any explicit calls to ``.cuda()`` or ``.to(device)``, you can remove them, since Lightning makes sure that the data coming from :class:`~torch.utils.data.DataLoader`
and all the :class:`~torch.nn.Module` instances initialized inside ``LightningModule.__init__`` are moved to the respective devices automatically.
If you still need to access the current device, you can use ``self.device`` anywhere in your ``LightningModule`` except in the ``__init__`` method. If you are initializing a
:class:`~torch.Tensor` within the ``LightningModule.__init__`` method and want it to be moved to the device automatically, you must call :meth:`~torch.nn.Module.register_buffer`
to register it as a buffer.

.. testcode::

    class LitModel(LightningModule):
        def training_step(self, batch, batch_idx):
            z = torch.randn(4, 5, device=self.device)
            ...
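
Similarly, a minimal sketch of the buffer registration described above (the buffer name ``sigma`` and its value are illustrative):

.. testcode::

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            # registered as a buffer, so it is moved to the right device along with the module
            self.register_buffer("sigma", torch.eye(3))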

--------

********************
8. Use your own data
********************

To use your DataLoaders, you can override the respective dataloader hooks in the :class:`~pytorch_lightning.core.lightning.LightningModule`:

.. testcode::

    class LitModel(LightningModule):
        def train_dataloader(self):
            return DataLoader(...)

        def val_dataloader(self):
            return DataLoader(...)

        def test_dataloader(self):
            return DataLoader(...)

        def predict_dataloader(self):
            return DataLoader(...)

Alternatively, you can pass your dataloaders in one of the following ways:

* Pass the dataloaders explicitly inside ``trainer.fit/.validate/.test/.predict`` calls, as shown in the sketch below.
* Use a :ref:`LightningDataModule <datamodules>`.
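
For the first option, a minimal sketch (``train_loader`` and ``val_loader`` are placeholder ``DataLoader`` instances you construct yourself; they are passed positionally here since the keyword names vary slightly across Lightning versions):

.. code-block:: python

    train_loader = DataLoader(...)
    val_loader = DataLoader(...)

    model = LitModel()
    trainer.fit(model, train_loader, val_loader)
    trainer.validate(model, val_loader)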

Check out the :ref:`data` doc to understand data management within Lightning.