Commit d13e5c9

document LightningModule better (#2920)
* updated docs
1 parent 580a5bd commit d13e5c9

15 files changed: +1752 / -842 lines

docs/source/child_modules.rst

Lines changed: 3 additions & 1 deletion
@@ -66,7 +66,9 @@ that change in the `Autoencoder` model are the init, forward, training, validati
         x_hat = self(representation)

         loss = F.nll_loss(logits, y)
-        return {f'{prefix}_loss': loss}
+        result = pl.EvalResult()
+        result.log(f'{prefix}_loss', loss)
+        return result


 and we can train this using the same trainer
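The EvalResult pattern added above is what lets one shared step serve both the validation and test loops of every child module. Below is a minimal sketch assuming the 0.9-era Result API; the `LitClassifier` name, the `_shared_eval` helper and the `nll_loss` objective are illustrative and not part of this commit.

.. code-block:: python

    import torch
    import pytorch_lightning as pl
    import torch.nn.functional as F


    class LitClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return torch.log_softmax(self.layer(x.view(x.size(0), -1)), dim=1)

        # hypothetical shared step reused by both eval loops
        def _shared_eval(self, batch, batch_idx, prefix):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)

            # log under a prefixed key instead of returning a dict
            result = pl.EvalResult()
            result.log(f'{prefix}_loss', loss)
            return result

        def validation_step(self, batch, batch_idx):
            return self._shared_eval(batch, batch_idx, 'val')

        def test_step(self, batch, batch_idx):
            return self._shared_eval(batch, batch_idx, 'test')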

docs/source/hyperparameters.rst

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ It is best practice to layer your arguments in three sections.
 2. Model specific arguments (layer_dim, num_layers, learning_rate, etc...)
 3. Program arguments (data_path, cluster_email, etc...)

+|
+
 We can do this as follows. First, in your LightningModule, define the arguments
 specific to that module. Remember that data splits or data paths may also be specific to
 a module (ie: if your project has a model that trains on Imagenet and another on CIFAR-10).

docs/source/introduction_guide.rst

Lines changed: 40 additions & 11 deletions
@@ -320,6 +320,8 @@ When your models need to know about the data, it's best to process the data befo
 1. use `prepare_data` to download and process the dataset.
 2. use `setup` to do splits, and build your model internals

+|
+
 .. testcode::

     class LitMNIST(LightningModule):
@@ -391,11 +393,11 @@ In the case of MNIST we do the following

     for epoch in epochs:
         for batch in data:
-            # TRAINING STEP START
+            # ------ TRAINING STEP START ------
             x, y = batch
             logits = model(x)
             loss = F.nll_loss(logits, y)
-            # TRAINING STEP END
+            # ------ TRAINING STEP END ------

             loss.backward()
             optimizer.step()
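The block between the START/END markers is exactly what becomes `training_step` once the loop is organized with Lightning; the epoch/batch loop, `backward()` and `optimizer.step()` are handled by the Trainer. A minimal sketch under that assumption, using the plain dict return that predates the `TrainResult` object introduced in the next section.

.. code-block:: python

    import pytorch_lightning as pl
    import torch.nn.functional as F


    class LitMNIST(pl.LightningModule):
        ...  # __init__, forward and configure_optimizers as defined earlier in the guide

        def training_step(self, batch, batch_idx):
            # ------ TRAINING STEP START ------
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            # ------ TRAINING STEP END ------

            # Lightning runs backward() and optimizer.step() for us
            return {'loss': loss}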
@@ -419,12 +421,13 @@ This code is not restricted which means it can be as complicated as a full seq-2

 TrainResult
 ^^^^^^^^^^^
-Whenever you'd like more control over the outputs of the `training_step` use a `TrainResult` object which can:
+Whenever you'd like to log, or sync values across GPUs use `TrainResult`.

 - log to Tensorboard or the other logger of your choice.
 - log to the progress-bar.
 - log on every step.
 - log aggregate epoch metrics.
+- average values across GPUs/TPU cores

 .. code-block:: python

@@ -441,6 +444,13 @@ Whenever you'd like more control over the outputs of the `training_step` use a `
     # equivalent
     result.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True, reduce_fx=torch.mean)

+When training across accelerators (GPUs/TPUs) you can sync a metric if needed.
+
+.. code-block:: python
+
+    # sync across GPUs / TPUs, etc...
+    result.log('train_loss', loss, sync_dist=True)
+
 If you are only using a training_loop (`training_step`) without a
 validation or test loop (`validation_step`, `test_step`), you can still use EarlyStopping or automatic checkpointing

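Putting the pieces above together, a `training_step` built around `TrainResult` might look like the sketch below, assuming the 0.9-era Result API: `minimize` names the tensor Lightning backpropagates, and `sync_dist=True` is the cross-device averaging listed in the bullet points.

.. code-block:: python

    import pytorch_lightning as pl
    import torch.nn.functional as F


    class LitMNIST(pl.LightningModule):
        ...  # __init__, forward and configure_optimizers as defined earlier in the guide

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)

            # `minimize` is the tensor Lightning calls backward() on
            result = pl.TrainResult(minimize=loss)

            # logged every step, shown in the progress bar,
            # averaged across GPUs/TPU cores because sync_dist=True
            result.log('train_loss', loss, prog_bar=True, sync_dist=True)
            return result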
@@ -460,6 +470,8 @@ So far we defined 4 key ingredients in pure PyTorch but organized the code with
 3. Optimizer.
 4. What happens in the training loop.

+|
+
 For clarity, we'll recall that the full LightningModule now looks like this.

 .. code-block:: python
@@ -533,6 +545,9 @@ Which will generate automatic tensorboard logs.

 .. figure:: /_images/mnist_imgs/mnist_tb.png
    :alt: mnist CPU bar
+   :width: 500
+
+|

 But you can also use any of the `number of other loggers <loggers.rst>`_ we support.

@@ -585,13 +600,20 @@ First, change the runtime to TPU (and reinstall lightning).

 .. figure:: /_images/mnist_imgs/runtime_tpu.png
    :alt: mnist GPU bar
+   :width: 400

 .. figure:: /_images/mnist_imgs/restart_runtime.png
    :alt: mnist GPU bar
+   :width: 400
+
+|

 Next, install the required xla library (adds support for PyTorch on TPUs)

+.. code-block:: shell
+
     !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
+
     !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

 In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy
@@ -607,14 +629,18 @@ In this method we do all the preparation we need to do once (instead of on every
 .. code-block:: python

     class MNISTDataModule(LightningDataModule):
+        def __init__(self, batch_size=64):
+            super().__init__()
+            self.batch_size = batch_size
+
         def prepare_data(self):
             # download only
             MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
             MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

         def setup(self, stage):
             # transform
-            transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
+            transform=transforms.Compose([transforms.ToTensor()])
             MNIST(os.getcwd(), train=True, download=False, transform=transform)
             MNIST(os.getcwd(), train=False, download=False, transform=transform)

@@ -627,13 +653,13 @@ In this method we do all the preparation we need to do once (instead of on every
             self.test_dataset = mnist_test

         def train_dataloader(self):
-            return DataLoader(self.train_dataset, batch_size=64)
+            return DataLoader(self.train_dataset, batch_size=self.batch_size)

         def val_dataloader(self):
-            return DataLoader(self.val_dataset, batch_size=64)
+            return DataLoader(self.val_dataset, batch_size=self.batch_size)

         def test_dataloader(self):
-            return DataLoader(self.test_dataset, batch_size=64)
+            return DataLoader(self.test_dataset, batch_size=self.batch_size)

 The `prepare_data` method is also a good place to do any data processing that needs to be done only
 once (ie: download or tokenize, etc...).
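With `batch_size` now an `__init__` argument, the DataModule is configured per run instead of by editing the hard-coded 64. A minimal usage sketch; the `LitMNIST` model comes from earlier in the guide, while the `tpu_cores=8` setting and the explicit `datamodule=` keyword are assumptions about how the surrounding TPU example is wired up.

.. code-block:: python

    from pytorch_lightning import Trainer

    # hypothetical run script
    dm = MNISTDataModule(batch_size=32)   # overrides the default of 64
    model = LitMNIST()

    trainer = Trainer(tpu_cores=8)        # or gpus=..., or leave empty for CPU
    trainer.fit(model, datamodule=dm)     # Lightning calls prepare_data, setup and the dataloaders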
@@ -653,11 +679,13 @@ You'll now see the TPU cores booting up.

 .. figure:: /_images/mnist_imgs/tpu_start.png
    :alt: TPU start
+   :width: 400

 Notice the epoch is MUCH faster!

 .. figure:: /_images/mnist_imgs/tpu_fast.png
    :alt: TPU speed
+   :width: 600

 ----------------

@@ -737,12 +765,13 @@ If you still need even more fine-grain control, define the other optional method
 .. code-block:: python

     def validation_step(self, batch, batch_idx):
-        val_step_output = {'step_output': x}
-        return val_step_output
+        result = pl.EvalResult()
+        result.prediction = some_prediction
+        return result

     def validation_epoch_end(self, val_step_outputs):
-        for val_step_output in val_step_outputs:
-            # each object here is what you passed back at each validation_step
+        # do something with all the predictions from each validation_step
+        all_predictions = val_step_outputs.prediction

 ----------------

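As the new snippet suggests, an attribute attached to the `EvalResult` in each `validation_step` is gathered across steps and exposed on the single object handed to `validation_epoch_end`. A minimal sketch assuming that behaviour; the argmax prediction is illustrative only.

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class LitClassifier(pl.LightningModule):
        ...  # __init__ and forward as in the earlier sketch

        def validation_step(self, batch, batch_idx):
            x, y = batch
            preds = torch.argmax(self(x), dim=1)

            result = pl.EvalResult()
            result.prediction = preds   # collected across all validation steps
            return result

        def validation_epoch_end(self, val_step_outputs):
            # one object holding the predictions from every validation_step
            all_predictions = val_step_outputs.prediction
            # e.g. compute dataset-level metrics from all_predictions here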