@@ -383,12 +383,14 @@ def prepare_data(self):
383383 model.test_dataloader()
384384 """
385385
386- def train_dataloader (self ) -> DataLoader :
386+ def train_dataloader (self ) -> Any :
387387 """
388- Implement a PyTorch DataLoader for training.
388+ Implement one or more PyTorch DataLoaders for training.
389389
390390 Return:
391- Single PyTorch :class:`~torch.utils.data.DataLoader`.
391+ Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
392+ (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
393+ this :ref:`page <multiple-training-dataloaders>`
392394
393395 The dataloader you return will not be called every epoch unless you set
394396 :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -414,6 +416,7 @@ def train_dataloader(self) -> DataLoader:
414416
415417 Example::
416418
419+ # single dataloader
417420 def train_dataloader(self):
418421 transform = transforms.Compose([transforms.ToTensor(),
419422 transforms.Normalize((0.5,), (1.0,))])
@@ -426,6 +429,32 @@ def train_dataloader(self):
426429 )
427430 return loader
428431
432+ # multiple dataloaders, return as list
433+ def train_dataloader(self):
434+ mnist = MNIST(...)
435+ cifar = CIFAR(...)
436+ mnist_loader = torch.utils.data.DataLoader(
437+ dataset=mnist, batch_size=self.batch_size, shuffle=True
438+ )
439+ cifar_loader = torch.utils.data.DataLoader(
440+ dataset=cifar, batch_size=self.batch_size, shuffle=True
441+ )
442+ # each batch will be a list of tensors: [batch_mnist, batch_cifar]
443+ return [mnist_loader, cifar_loader]
444+
445+ # multiple dataloader, return as dict
446+ def train_dataloader(self):
447+ mnist = MNIST(...)
448+ cifar = CIFAR(...)
449+ mnist_loader = torch.utils.data.DataLoader(
450+ dataset=mnist, batch_size=self.batch_size, shuffle=True
451+ )
452+ cifar_loader = torch.utils.data.DataLoader(
453+ dataset=cifar, batch_size=self.batch_size, shuffle=True
454+ )
455+ # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
456+ return {'mnist': mnist_loader, 'cifar': cifar_loader}
457+
429458 """
430459 rank_zero_warn ("`train_dataloader` must be implemented to be used with the Lightning Trainer" )
431460
0 commit comments