@@ -134,6 +134,7 @@ Under the hood a LightningModule is still just a :class:`torch.nn.Module` that g
134134- The Train loop
135135- The Validation loop
136136- The Test loop
137+ - The Prediction loop
137138- The Model or system of Models
138139- The Optimizer
139140
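Concretely, a minimal sketch of a module that groups those pieces might look like the following (the autoencoder layers, loss and learning rate are illustrative placeholders, not prescribed by Lightning):

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch import nn
    import pytorch_lightning as pl


    class LitAutoEncoder(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # the model (or system of models)
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

        def training_step(self, batch, batch_idx):
            # the train loop
            x, _ = batch
            x = x.view(x.size(0), -1)
            x_hat = self.decoder(self.encoder(x))
            return F.mse_loss(x_hat, x)

        def configure_optimizers(self):
            # the optimizer
            return torch.optim.Adam(self.parameters(), lr=1e-3)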
@@ -181,7 +182,7 @@ More details in :doc:`lightning module <../common/lightning_module>` docs.
181182Step 2: Fit with Lightning Trainer
182183**********************************
183184
184- First, define the data however you want. Lightning just needs a :class:`~torch.utils.data.DataLoader` for the train/val/test splits.
185+ First, define the data however you want. Lightning just needs a :class:`~torch.utils.data.DataLoader` for the train/val/test/predict splits.
185186
186187.. code-block:: python
187188
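    # a hedged sketch of such a data definition; MNIST and the 55k/5k split are
    # placeholders here, any dataset wrapped in a DataLoader works
    import os

    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST

    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])

    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)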
@@ -258,7 +259,8 @@ Turn off automatic optimization and you control the train loop!
258259
259260
260261 def training_step(self, batch, batch_idx):
261- # access your optimizers with use_pl_optimizer=False. Default is True
262+ # access your optimizers with use_pl_optimizer=False. Default is True,
263+ # setting use_pl_optimizer=True will maintain plugin/precision support
262264 opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
263265
264266 loss_a = self.generator(batch)
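    # a hedged sketch of how the manual loop typically continues
    # (assumes self.automatic_optimization = False was set in __init__);
    # the same zero_grad / manual_backward / step pattern applies to opt_b
    opt_a.zero_grad()
    self.manual_backward(loss_a)
    opt_a.step()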
@@ -321,7 +323,7 @@ You can also add a forward method to do predictions however you want.
321323
322324
323325 autoencoder = LitAutoEncoder()
324- autoencoder = autoencoder(torch.rand(1, 28 * 28))
326+ embedding = autoencoder(torch.rand(1, 28 * 28))
325327
326328
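To make that concrete, a minimal sketch of such a forward method (assuming the encoder attribute from the LitAutoEncoder sketched earlier) could be:

.. code-block:: python

    class LitAutoEncoder(pl.LightningModule):
        def forward(self, x):
            # prediction/inference: return the learned embedding for x
            embedding = self.encoder(x)
            return embedding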
327329.. code-block:: python
@@ -371,9 +373,9 @@ a forward method or trace only the sub-models you need.
371373
372374--------------------
373375
374- Using CPUs/GPUs/TPUs
375- ====================
376- It's trivial to use CPUs, GPUs or TPUs in Lightning. There's **NO NEED** to change your code; simply change the :class:`~pytorch_lightning.trainer.Trainer` options.
376+ Using CPUs/GPUs/TPUs/IPUs
377+ =========================
378+ It's trivial to use CPUs, GPUs, TPUs or IPUs in Lightning. There's **NO NEED** to change your code; simply change the :class:`~pytorch_lightning.trainer.Trainer` options.
377379
378380.. testcode::
379381
@@ -423,6 +425,11 @@ Without changing a SINGLE line of your code, you can now do the following with t
423425 # using only half the training data and checking validation every quarter of a training epoch
424426 trainer = pl.Trainer(tpu_cores=8, precision=16, limit_train_batches=0.5, val_check_interval=0.25)
425427
428+ .. code-block:: python
429+
430+ # Train on IPUs
431+ trainer = pl.Trainer(ipus=8)
432+
426433-----------
427434
428435Checkpoints
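Lightning checkpoints the full training state for you automatically; when you want to save or restore a model by hand, a minimal sketch of that route (the checkpoint path is a placeholder) is:

.. code-block:: python

    # save a checkpoint explicitly (Lightning also saves one automatically)
    trainer.save_checkpoint("example.ckpt")

    # rebuild the LightningModule from the saved checkpoint
    model = LitAutoEncoder.load_from_checkpoint("example.ckpt")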
@@ -449,7 +456,7 @@ If you prefer to do it manually, here's the equivalent
449456
450457Data flow
451458=========
452- Each loop (training, validation, test) has three hooks you can implement:
459+ Each loop (training, validation, test, predict) has three hooks you can implement:
453460
454461- x_step
455462- x_step_end
@@ -474,8 +481,8 @@ The equivalent in Lightning is:
474481 return prediction
475482
476483
477- def training_epoch_end(self, training_step_outputs):
478- for prediction in predictions:
484+ def training_epoch_end(self, outs):
485+ for out in outs:
479486 ...
480487
481488 In the event that you use DP or DDP2 distributed modes (i.e., split a batch across GPUs),
@@ -508,9 +515,9 @@ The lightning equivalent is:
508515 def training_step_end(self, losses):
509516 gpu_0_loss = losses[0]
510517 gpu_1_loss = losses[1]
511- return (gpu_0_loss + gpu_1_loss) * 1 / 2
518+ return (gpu_0_loss + gpu_1_loss) / 2
512519
513- .. tip:: The validation and test loops have the same structure.
520+ .. tip:: The validation, test and prediction loops have the same structure.
514521
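As an illustration, a hedged sketch of the matching validation hooks (mirroring the training hooks above; the model call and the handling of ``outs`` are placeholders) could be:

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        prediction = self(x)
        return prediction

    def validation_epoch_end(self, outs):
        # `outs` gathers whatever each validation_step returned for the epoch
        for out in outs:
            ...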
515522-----------------
516523
@@ -648,8 +655,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
648655 if stage in (None, "fit"):
649656 mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
650657 self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
651- if stage == (None, "test"):
658+ if stage == "test":
652659 self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
660+ if stage == "predict":
661+ self.mnist_predict = MNIST(os.getcwd(), train=False, transform=transform)
653662
654663 # return the dataloader for each split
655664 def train_dataloader(self):
@@ -664,6 +673,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
664673 mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
665674 return mnist_test
666675
676+ def predict_dataloader(self):
677+ mnist_predict = DataLoader(self.mnist_predict, batch_size=self.batch_size)
678+ return mnist_predict
679+
667680:class:`~pytorch_lightning.core.datamodule.LightningDataModule` is designed to enable sharing and reusing data splits
668681and transforms across different projects. It encapsulates all the steps needed to process data: downloading,
669682tokenizing, processing etc.
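Assuming the datamodule above is wrapped in a class exposing the ``batch_size`` used by its dataloaders (the name ``MNISTDataModule`` is a placeholder for whatever you called it), hooking it up only takes instantiating it before passing it to the Trainer:

.. code-block:: python

    dm = MNISTDataModule(batch_size=32)
    model = LitAutoEncoder()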
@@ -681,11 +694,17 @@ the :class:`~pytorch_lightning.trainer.Trainer`:
681694
682695 # train
683696 trainer = pl.Trainer()
684- trainer.fit(model, dm)
697+ trainer.fit(model, datamodule=dm)
698+
699+ # validate
700+ trainer.validate(datamodule=dm)
685701
686702 # test
687703 trainer.test(datamodule=dm)
688704
705+ # predict
706+ predictions = trainer.predict(datamodule=dm)
707+
689708 DataModules are specifically useful for building models based on data. Read more on :doc:`datamodules <../extensions/datamodules>`.
690709
691710------
@@ -701,15 +720,18 @@ Lightning has many tools for debugging. Here is an example of just a few of them
701720
702721.. testcode::
703722
704- # Automatically overfit the sane batch of your model for a sanity test
723+ # Automatically overfit the same batch of your model for a sanity test
705724 trainer = Trainer(overfit_batches=1)
706725
707726.. testcode::
708727
709- # unit test all the code- hits every line of your code once to see if you have bugs,
728+ # unit test all the code - hits every line of your code once to see if you have bugs,
710729 # instead of waiting hours to crash on validation
711730 trainer = Trainer(fast_dev_run=True)
712731
732+ # unit test all the code - hits every line of your code with 4 batches
733+ trainer = Trainer(fast_dev_run=4)
734+
713735.. testcode::
714736
715737 # train only 20% of an epoch
@@ -739,7 +761,7 @@ Once you define and train your first Lightning model, you might want to try othe
739761- :doc:`Automatically find a good learning rate <../advanced/lr_finder>`
740762- :ref:`Load checkpoints directly from S3 <common/weights_loading:Checkpoint Loading>`
741763- :doc:`Scale to massive compute clusters <../clouds/cluster>`
742- - :doc:`Use multiple dataloaders per train/val/test loop <../guides/data>`
764+ - :doc:`Use multiple dataloaders per train/val/test/predict loop <../guides/data>`
743765- :ref:`Use multiple optimizers to do reinforcement learning or even GANs <common/optimizers:Use multiple optimizers (like GANs)>`
744766
745767Or read our :doc:`Guide <../starter/introduction_guide>` to learn more!