@@ -320,6 +320,8 @@ When your models need to know about the data, it's best to process the data befo
320320 1. use `prepare_data ` to download and process the dataset.
3213212. use `setup` to do splits and build your model internals.
322322
323+ |
324+
323325.. testcode ::
324326
325327 class LitMNIST(LightningModule):
@@ -391,11 +393,11 @@ In the case of MNIST we do the following
391393
392394 for epoch in epochs:
393395 for batch in data:
394- # TRAINING STEP START
396+ # ------ TRAINING STEP START ------
395397 x, y = batch
396398 logits = model(x)
397399 loss = F.nll_loss(logits, y)
398- # TRAINING STEP END
400+ # ------ TRAINING STEP END ------
399401
400402 loss.backward()
401403 optimizer.step()
@@ -419,12 +421,13 @@ This code is not restricted which means it can be as complicated as a full seq-2
419421
420422TrainResult
421423^^^^^^^^^^^
422- Whenever you'd like more control over the outputs of the ` training_step ` use a `TrainResult ` object which can:
424+ Whenever you'd like to log or sync values across GPUs, use a `TrainResult` object, which can:
423425
424426- log to Tensorboard or the other logger of your choice.
425427- log to the progress-bar.
426428- log on every step.
427429- log aggregate epoch metrics.
430+ - average values across GPUs/TPU cores.
428431
429432.. code-block :: python
430433
@@ -441,6 +444,13 @@ Whenever you'd like more control over the outputs of the `training_step` use a `
441444 # equivalent
442445 result.log(' train_loss' , loss, on_step = True , on_epoch = False , prog_bar = False , logger = True , reduce_fx = torch.mean)
443446
447+ When training across accelerators (GPUs/TPUs), you can sync a metric if needed.
448+
449+ .. code-block :: python
450+
451+ # sync across GPUs / TPUs, etc...
452+ result.log(' train_loss' , loss, sync_dist = True )
453+
444454 If you are only using a training loop (`training_step`) without a
445455validation or test loop (`validation_step`, `test_step`), you can still use EarlyStopping or automatic checkpointing
446456
@@ -460,6 +470,8 @@ So far we defined 4 key ingredients in pure PyTorch but organized the code with
4604703. Optimizer.
4614714. What happens in the training loop.
462472
473+ |
474+
463475For clarity, we'll recall that the full LightningModule now looks like this.
464476
465477.. code-block :: python
@@ -533,6 +545,9 @@ Which will generate automatic tensorboard logs.
533545
534546.. figure :: /_images/mnist_imgs/mnist_tb.png
535547 :alt: mnist tensorboard logs
548+ :width: 500
549+
550+ |
536551
537552But you can also use any of the `number of other loggers <loggers.rst >`_ we support.
538553
@@ -585,13 +600,20 @@ First, change the runtime to TPU (and reinstall lightning).
585600
586601.. figure :: /_images/mnist_imgs/runtime_tpu.png
587602 :alt: change runtime to TPU
603+ :width: 400
588604
589605.. figure :: /_images/mnist_imgs/restart_runtime.png
590606 :alt: restart runtime
607+ :width: 400
608+
609+ |
591610
592611Next, install the required XLA library (adds support for PyTorch on TPUs):
593612
613+ .. code-block :: shell
614+
594615 ! curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
616+
595617 ! python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
596618
597619 In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy
@@ -607,14 +629,18 @@ In this method we do all the preparation we need to do once (instead of on every
607629.. code-block :: python
608630
609631 class MNISTDataModule (LightningDataModule ):
632+ def __init__ (self , batch_size = 64 ):
633+ super ().__init__ ()
634+ self .batch_size = batch_size
635+
610636 def prepare_data (self ):
611637 # download only
612638 MNIST(os.getcwd(), train = True , download = True , transform = transforms.ToTensor())
613639 MNIST(os.getcwd(), train = False , download = True , transform = transforms.ToTensor())
614640
615641 def setup (self , stage ):
616642 # transform
617- transform= transforms.Compose([transforms.ToTensor(), transforms.Normalize(( 0.1307 ,), ( 0.3081 ,)) ])
643+ transform= transforms.Compose([transforms.ToTensor()])
618644 MNIST(os.getcwd(), train = True , download = False , transform = transform)
619645 MNIST(os.getcwd(), train = False , download = False , transform = transform)
620646
@@ -627,13 +653,13 @@ In this method we do all the preparation we need to do once (instead of on every
627653 self .test_dataset = mnist_test
628654
629655 def train_dataloader (self ):
630- return DataLoader(self .train_dataset, batch_size = 64 )
656+ return DataLoader(self .train_dataset, batch_size = self .batch_size )
631657
632658 def val_dataloader (self ):
633- return DataLoader(self .val_dataset, batch_size = 64 )
659+ return DataLoader(self .val_dataset, batch_size = self .batch_size )
634660
635661 def test_dataloader (self ):
636- return DataLoader(self .test_dataset, batch_size = 64 )
662+ return DataLoader(self .test_dataset, batch_size = self .batch_size )
637663
638664 The `prepare_data` method is also a good place to do any data processing that needs to be done only
639665once (e.g., downloading or tokenizing).
@@ -653,11 +679,13 @@ You'll now see the TPU cores booting up.
653679
654680.. figure :: /_images/mnist_imgs/tpu_start.png
655681 :alt: TPU start
682+ :width: 400
656683
657684Notice the epoch is MUCH faster!
658685
659686.. figure :: /_images/mnist_imgs/tpu_fast.png
660687 :alt: TPU speed
688+ :width: 600
661689
662690----------------
663691
@@ -737,12 +765,13 @@ If you still need even more fine-grain control, define the other optional method
737765.. code-block :: python
738766
739767 def validation_step (self , batch , batch_idx ):
740- val_step_output = {' step_output' : x}
741- return val_step_output
768+ result = pl.EvalResult()
769+ result.prediction = some_prediction
770+ return result
742771
743772 def validation_epoch_end (self , val_step_outputs ):
744- for val_step_output in val_step_outputs:
745- # each object here is what you passed back at each validation_step
773+ # do something with all the predictions from each validation_step
774+ all_predictions = val_step_outputs.prediction
746775
747776----------------
748777
0 commit comments