23 changes: 23 additions & 0 deletions docs/source/advanced/multiple_loaders.rst
@@ -16,6 +16,8 @@ Lightning supports multiple dataloaders in a few ways.

----------

.. _multiple-training-dataloaders:

Multiple training dataloaders
-----------------------------
For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
@@ -86,6 +88,27 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer

return loaders

Furthermore, Lightning also supports returning nested lists and dicts (or a combination of the two):

.. testcode::

    class LitModel(LightningModule):

        def train_dataloader(self):

            loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
            loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
            loader_c = torch.utils.data.DataLoader(range(32), batch_size=4)
            loader_d = torch.utils.data.DataLoader(range(64), batch_size=4)

            # pass loaders as a nested dict. This will create batches like this:
            # {'loaders_a_b': {'a': batch from loader_a, 'b': batch from loader_b},
            #  'loaders_c_d': {'c': batch from loader_c, 'd': batch from loader_d}}
            loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b},
                       'loaders_c_d': {'c': loader_c, 'd': loader_d}}
            return loaders
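
Each batch that reaches ``training_step`` then mirrors this nested structure, so the
individual batches can be looked up by key. A minimal sketch of the consuming side,
assuming the loader names above (the ``training_step`` shown here is illustrative and
not part of this change):

.. testcode::

    class LitModel(LightningModule):

        def training_step(self, batch, batch_idx):
            # batch is a nested dict mirroring the dict returned by train_dataloader
            batch_a = batch['loaders_a_b']['a']  # one batch from loader_a
            batch_b = batch['loaders_a_b']['b']  # one batch from loader_b
            batch_c = batch['loaders_c_d']['c']  # one batch from loader_c
            batch_d = batch['loaders_c_d']['d']  # one batch from loader_d
            ...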

----------

Test/Val dataloaders
35 changes: 32 additions & 3 deletions pytorch_lightning/core/hooks.py
@@ -383,12 +383,14 @@ def prepare_data(self):
model.test_dataloader()
"""

def train_dataloader(self) -> DataLoader:
def train_dataloader(self) -> Any:
"""
Implement a PyTorch DataLoader for training.
Implement one or more PyTorch DataLoaders for training.

Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
this :ref:`page <multiple-training-dataloaders>`.

The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -414,6 +416,7 @@ def train_dataloader(self) -> DataLoader:

Example::

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
@@ -426,6 +429,32 @@ def train_dataloader(self):
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
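
# Illustrative sketch (not part of this diff): with the dict form above, the
# batch arriving in `training_step` is a dict keyed the same way, so each
# loader's batch can be picked out by name.
def training_step(self, batch, batch_idx):
    mnist_batch = batch['mnist']
    cifar_batch = batch['cifar']
    ...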

"""
rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")

9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -15,7 +15,7 @@
import warnings
from itertools import count
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
from torch.utils.data import DataLoader
@@ -425,7 +425,7 @@ def setup_trainer(self, model: LightningModule):
def fit(
    self,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    train_dataloader: Any = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    datamodule: Optional[LightningDataModule] = None,
):
@@ -437,8 +437,9 @@ def fit(

model: Model to fit.

train_dataloader: A Pytorch DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.
train_dataloader: Either a single PyTorch DataLoader or a collection of these
(list, dict, nested lists and dicts). In the case of multiple dataloaders, please
see this :ref:`page <multiple-training-dataloaders>`.

val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped.
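
For reference, a minimal sketch of calling ``fit`` under the widened ``train_dataloader``
type, assuming ``model`` and the loaders have already been constructed (illustrative
only, not part of this diff)::

    # a single loader still works as before
    trainer.fit(model, train_dataloader=mnist_loader)

    # a list or dict of loaders is now accepted as well
    trainer.fit(model, train_dataloader=[mnist_loader, cifar_loader])
    trainer.fit(model, train_dataloader={'mnist': mnist_loader, 'cifar': cifar_loader})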