Skip to content

Commit b8ac176

Browse files
Docs: fix mistakes in New Project docs (#10137)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 85eb17c commit b8ac176

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

docs/source/starter/new-project.rst

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ Under the hood a LightningModule is still just a :class:`torch.nn.Module` that g
134134
- The Train loop
135135
- The Validation loop
136136
- The Test loop
137+
- The Prediction loop
137138
- The Model or system of Models
138139
- The Optimizer
139140

@@ -181,7 +182,7 @@ More details in :doc:`lightning module <../common/lightning_module>` docs.
181182
Step 2: Fit with Lightning Trainer
182183
**********************************
183184

184-
First, define the data however you want. Lightning just needs a :class:`~torch.utils.data.DataLoader` for the train/val/test splits.
185+
First, define the data however you want. Lightning just needs a :class:`~torch.utils.data.DataLoader` for the train/val/test/predict splits.
185186

186187
.. code-block:: python
187188
@@ -258,7 +259,8 @@ Turn off automatic optimization and you control the train loop!
258259
259260
260261
def training_step(self, batch, batch_idx):
261-
# access your optimizers with use_pl_optimizer=False. Default is True
262+
# access your optimizers with use_pl_optimizer=False. Default is True,
263+
# setting use_pl_optimizer=True will maintain plugin/precision support
262264
opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
263265
264266
loss_a = self.generator(batch)
@@ -321,7 +323,7 @@ You can also add a forward method to do predictions however you want.
321323

322324

323325
autoencoder = LitAutoEncoder()
324-
autoencoder = autoencoder(torch.rand(1, 28 * 28))
326+
embedding = autoencoder(torch.rand(1, 28 * 28))
325327

326328

327329
.. code-block:: python
@@ -371,9 +373,9 @@ a forward method or trace only the sub-models you need.
371373
372374
--------------------
373375

374-
Using CPUs/GPUs/TPUs
375-
====================
376-
It's trivial to use CPUs, GPUs or TPUs in Lightning. There's **NO NEED** to change your code, simply change the :class:`~pytorch_lightning.trainer.Trainer` options.
376+
Using CPUs/GPUs/TPUs/IPUs
377+
=========================
378+
It's trivial to use CPUs, GPUs, TPUs or IPUs in Lightning. There's **NO NEED** to change your code, simply change the :class:`~pytorch_lightning.trainer.Trainer` options.
377379

378380
.. testcode::
379381

@@ -423,6 +425,11 @@ Without changing a SINGLE line of your code, you can now do the following with t
423425
# using only half the training data and checking validation every quarter of a training epoch
424426
trainer = pl.Trainer(tpu_cores=8, precision=16, limit_train_batches=0.5, val_check_interval=0.25)
425427
428+
.. code-block:: python
429+
430+
# Train on IPUs
431+
trainer = pl.Trainer(ipus=8)
432+
426433
-----------
427434

428435
Checkpoints
@@ -449,7 +456,7 @@ If you prefer to do it manually, here's the equivalent
449456

450457
Data flow
451458
=========
452-
Each loop (training, validation, test) has three hooks you can implement:
459+
Each loop (training, validation, test, predict) has three hooks you can implement:
453460

454461
- x_step
455462
- x_step_end
@@ -474,8 +481,8 @@ The equivalent in Lightning is:
474481
return prediction
475482
476483
477-
def training_epoch_end(self, training_step_outputs):
478-
for prediction in predictions:
484+
def training_epoch_end(self, outs):
485+
for out in outs:
479486
...
480487
481488
In the event that you use DP or DDP2 distributed modes (ie: split a batch across GPUs),
@@ -508,9 +515,9 @@ The lightning equivalent is:
508515
def training_step_end(self, losses):
509516
gpu_0_loss = losses[0]
510517
gpu_1_loss = losses[1]
511-
return (gpu_0_loss + gpu_1_loss) * 1 / 2
518+
return (gpu_0_loss + gpu_1_loss) / 2
512519
513-
.. tip:: The validation and test loops have the same structure.
520+
.. tip:: The validation, test and prediction loops have the same structure.
514521

515522
-----------------
516523

@@ -648,8 +655,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
648655
if stage in (None, "fit"):
649656
mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
650657
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
651-
if stage == (None, "test"):
658+
if stage == "test":
652659
self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
660+
if stage == "predict":
661+
self.mnist_predict = MNIST(os.getcwd(), train=False, transform=transform)
653662

654663
# return the dataloader for each split
655664
def train_dataloader(self):
@@ -664,6 +673,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
664673
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
665674
return mnist_test
666675

676+
def predict_dataloader(self):
677+
mnist_predict = DataLoader(self.mnist_predict, batch_size=self.batch_size)
678+
return mnist_predict
679+
667680
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` is designed to enable sharing and reusing data splits
668681
and transforms across different projects. It encapsulates all the steps needed to process data: downloading,
669682
tokenizing, processing etc.
@@ -681,11 +694,17 @@ the :class:`~pytorch_lightning.trainer.Trainer`:
681694
682695
# train
683696
trainer = pl.Trainer()
684-
trainer.fit(model, dm)
697+
trainer.fit(model, datamodule=dm)
698+
699+
# validate
700+
trainer.validate(datamodule=dm)
685701
686702
# test
687703
trainer.test(datamodule=dm)
688704
705+
# predict
706+
predictions = trainer.predict(datamodule=dm)
707+
689708
DataModules are specifically useful for building models based on data. Read more on :doc:`datamodules <../extensions/datamodules>`.
690709

691710
------
@@ -701,15 +720,18 @@ Lightning has many tools for debugging. Here is an example of just a few of them
701720

702721
.. testcode::
703722

704-
# Automatically overfit the sane batch of your model for a sanity test
723+
# Automatically overfit the same batch of your model for a sanity test
705724
trainer = Trainer(overfit_batches=1)
706725

707726
.. testcode::
708727

709-
# unit test all the code- hits every line of your code once to see if you have bugs,
728+
# unit test all the code - hits every line of your code once to see if you have bugs,
710729
# instead of waiting hours to crash on validation
711730
trainer = Trainer(fast_dev_run=True)
712731

732+
# unit test all the code - hits every line of your code with 4 batches
733+
trainer = Trainer(fast_dev_run=4)
734+
713735
.. testcode::
714736

715737
# train only 20% of an epoch
@@ -739,7 +761,7 @@ Once you define and train your first Lightning model, you might want to try othe
739761
- :doc:`Automatically find a good learning rate <../advanced/lr_finder>`
740762
- :ref:`Load checkpoints directly from S3 <common/weights_loading:Checkpoint Loading>`
741763
- :doc:`Scale to massive compute clusters <../clouds/cluster>`
742-
- :doc:`Use multiple dataloaders per train/val/test loop <../guides/data>`
764+
- :doc:`Use multiple dataloaders per train/val/test/predict loop <../guides/data>`
743765
- :ref:`Use multiple optimizers to do reinforcement learning or even GANs <common/optimizers:Use multiple optimizers (like GANs)>`
744766

745767
Or read our :doc:`Guide <../starter/introduction_guide>` to learn more!

0 commit comments

Comments
 (0)