From eaed4b4893787261f8622569fead4487fac88ee0 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 19 Feb 2021 10:24:53 +0100
Subject: [PATCH 1/8] add to docs

---
 docs/source/advanced/multiple_loaders.rst | 2 ++
 pytorch_lightning/trainer/trainer.py      | 9 +++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst
index fb1aa33f80462..b61a8a1688b6d 100644
--- a/docs/source/advanced/multiple_loaders.rst
+++ b/docs/source/advanced/multiple_loaders.rst
@@ -16,6 +16,8 @@ Lightning supports multiple dataloaders in a few ways.
 
 ----------
 
+.. _multiple-training-dataloaders:
+
 Multiple training dataloaders
 -----------------------------
 For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 10545a075cb32..4a5c29819bf06 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -15,7 +15,7 @@
 import warnings
 from itertools import count
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch
 from torch.utils.data import DataLoader
@@ -425,7 +425,7 @@ def setup_trainer(self, model: LightningModule):
     def fit(
         self,
         model: LightningModule,
-        train_dataloader: Optional[DataLoader] = None,
+        train_dataloader: Any = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         datamodule: Optional[LightningDataModule] = None,
     ):
@@ -437,8 +437,9 @@ def fit(
 
             model: Model to fit.
 
-            train_dataloader: A Pytorch DataLoader with training samples. If the model has
-                a predefined train_dataloader method this will be skipped.
+            train_dataloader: Either a single Pytorch Dataloader or an collection of these
+                (list, dict, nested lists and dicts). In the case of multiple dataloaders, please
+                see this :ref:`page <multiple-training-dataloaders>`
 
             val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                 If the model has a predefined val_dataloaders method this will be skipped
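
Note: a minimal sketch of what the widened ``fit`` signature above accepts once
``train_dataloader`` is typed as ``Any``. The model class and loader contents are
illustrative stand-ins, not taken from the patch::

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import Trainer

    model = LitClassifier()  # any LightningModule; hypothetical name
    loader_a = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=4)
    loader_b = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=4)

    trainer = Trainer()
    trainer.fit(model, train_dataloader=loader_a)                         # single loader, as before
    trainer.fit(model, train_dataloader=[loader_a, loader_b])             # list of loaders
    trainer.fit(model, train_dataloader={'a': loader_a, 'b': loader_b})   # dict of loaders
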
From 451ef1522d995a922e29381d2463250be77da06c Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 19 Feb 2021 12:25:58 +0100
Subject: [PATCH 2/8] update docs

---
 pytorch_lightning/core/hooks.py | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index e0b33c1219e8b..ce9efb9db1846 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -385,10 +385,12 @@ def prepare_data(self):
 
     def train_dataloader(self) -> DataLoader:
         """
-        Implement a PyTorch DataLoader for training.
+        Implement one or more PyTorch DataLoaders for training.
 
         Return:
-            Single PyTorch :class:`~torch.utils.data.DataLoader`.
+            Either a single PyTorch :class:`~torch.utils.data.DataLoader` or an collection of these
+            (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
+            this :ref:`page <multiple-training-dataloaders>`
 
         The dataloader you return will not be called every epoch unless you set
         :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -413,7 +415,7 @@ def train_dataloader(self):
             There is no need to set it yourself.
 
         Example::
-
+            # single dataloader
            def train_dataloader(self):
                transform = transforms.Compose([transforms.ToTensor(),
                                                transforms.Normalize((0.5,), (1.0,))])
@@ -426,6 +428,22 @@ def train_dataloader(self):
                 )
                 return loader
 
+            # multiple dataloaders, return as list
+            def train_dataloader(self):
+                mnist = MNIST(...)
+                cifar = CIFAR(...)
+                # each batch will be a list of tensors: [batch_mnist, batch_cifar]
+                return [torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True),
+                        torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)]
+
+            # multiple dataloader, return as dict
+            def train_dataloader(self):
+                mnist = MNIST(...)
+                cifar = CIFAR(...)
+                # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
+                return {'mnist': torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True),
+                        'cifar': torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)]
+
         """
         rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")

From 4676c1a7035f55e6483617385d64963b6d21b622 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Fri, 19 Feb 2021 13:09:29 +0100
Subject: [PATCH 3/8] Apply suggestions from code review

---
 pytorch_lightning/core/hooks.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index ce9efb9db1846..951f754192829 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -415,6 +415,7 @@ def train_dataloader(self):
             There is no need to set it yourself.
 
         Example::
+
             # single dataloader
             def train_dataloader(self):
                 transform = transforms.Compose([transforms.ToTensor(),

From 7d08ba59bd1be9e92057560d86975e9d29af4613 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 19 Feb 2021 22:42:25 +0100
Subject: [PATCH 4/8] Update pytorch_lightning/core/hooks.py

Co-authored-by: Rohit Gupta
---
 pytorch_lightning/core/hooks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 951f754192829..f2e913776e63f 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -383,7 +383,7 @@ def prepare_data(self):
             model.test_dataloader()
         """
 
-    def train_dataloader(self) -> DataLoader:
+    def train_dataloader(self) -> Any:
         """
         Implement one or more PyTorch DataLoaders for training.
 
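
Note: for the dict form documented above, each training batch arrives keyed by loader
name. A sketch of the consuming ``training_step``, assuming the ``'mnist'``/``'cifar'``
keys from the docstring example; ``self.loss`` is a hypothetical criterion::

    def training_step(self, batch, batch_idx):
        # batch is {'mnist': batch_mnist, 'cifar': batch_cifar}, one entry per loader
        mnist_x, mnist_y = batch['mnist']
        cifar_x, cifar_y = batch['cifar']
        loss = self.loss(self(mnist_x), mnist_y) + self.loss(self(cifar_x), cifar_y)
        return loss
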
From f93c36f3f5475fe4f5308ad0a844289ba3853f4a Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 19 Feb 2021 22:51:45 +0100
Subject: [PATCH 5/8] nested loaders

---
 docs/source/advanced/multiple_loaders.rst | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst
index b61a8a1688b6d..908a3692c3a45 100644
--- a/docs/source/advanced/multiple_loaders.rst
+++ b/docs/source/advanced/multiple_loaders.rst
@@ -88,6 +88,27 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer
 
         return loaders
 
+Furthermore, lightning also supports that nested lists and dicts (or an combination) can
+be returned
+
+.. testcode::
+
+    class LitModel(LightningModule):
+
+        def train_dataloader(self):
+
+            loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
+            loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
+            loader_c = torch.utils.data.DataLoader(range(32), batch_size=4)
+            loader_d = torch.utils.data.DataLoader(range(64), batch_size=4)
+
+            # pass loaders as a nested dict. This will create batches like this:
+            # {'loader_a_b': {'a': batch from loader a, 'b': batch from loader b},
+            #  'loader_c_d': {'c': batch from loader c, 'd': batch from loader d}}
+            loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b}
+                       'loaders_c_d': {'c': loader_c, 'd': loader_d}}
+            return loaders
+
 ----------
 
 Test/Val dataloaders

From b7caa7eb0277ce48c98bef7cad41cdee9b5629a3 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sun, 21 Feb 2021 13:33:32 +0100
Subject: [PATCH 6/8] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 docs/source/advanced/multiple_loaders.rst | 4 ++--
 pytorch_lightning/core/hooks.py           | 2 +-
 pytorch_lightning/trainer/trainer.py      | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst
index 908a3692c3a45..2e3e3201b2181 100644
--- a/docs/source/advanced/multiple_loaders.rst
+++ b/docs/source/advanced/multiple_loaders.rst
@@ -88,7 +88,7 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer
 
         return loaders
 
-Furthermore, lightning also supports that nested lists and dicts (or an combination) can
+Furthermore, Lightning also supports that nested lists and dicts (or a combination) can
 be returned
 
 .. testcode::
@@ -105,7 +105,7 @@ be returned
             # pass loaders as a nested dict. This will create batches like this:
             # {'loader_a_b': {'a': batch from loader a, 'b': batch from loader b},
             #  'loader_c_d': {'c': batch from loader c, 'd': batch from loader d}}
-            loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b}
+            loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b},
                        'loaders_c_d': {'c': loader_c, 'd': loader_d}}
             return loaders
 
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index f2e913776e63f..ca8e3243bd486 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -443,7 +443,7 @@ def train_dataloader(self):
                 cifar = CIFAR(...)
                 # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
                 return {'mnist': torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True),
-                        'cifar': torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)]
+                        'cifar': torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)}
 
         """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4a5c29819bf06..cc443bc92ea5d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -437,7 +437,7 @@ def fit(
 
             model: Model to fit.
 
-            train_dataloader: Either a single Pytorch Dataloader or an collection of these
+            train_dataloader: Either a single PyTorch DataLoader or a collection of these
                 (list, dict, nested lists and dicts). In the case of multiple dataloaders, please
                 see this :ref:`page <multiple-training-dataloaders>`
 
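
Note: with the nested form added in the docs above, batches mirror the loader
structure. A sketch of unpacking one such batch in ``training_step``, following the
``loaders_a_b``/``loaders_c_d`` layout from the testcode; illustrative only::

    def training_step(self, batch, batch_idx):
        # batch mirrors the nested loader dict:
        # {'loaders_a_b': {'a': batch_a, 'b': batch_b},
        #  'loaders_c_d': {'c': batch_c, 'd': batch_d}}
        batch_a = batch['loaders_a_b']['a']
        batch_d = batch['loaders_c_d']['d']
        ...
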
From bf836d20ec32caa99dfb62359f687bee8ec4accb Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sun, 21 Feb 2021 13:38:12 +0100
Subject: [PATCH 7/8] shorten text length

---
 pytorch_lightning/core/hooks.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index ca8e3243bd486..0c30360cff785 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -433,17 +433,27 @@ def train_dataloader(self):
             def train_dataloader(self):
                 mnist = MNIST(...)
                 cifar = CIFAR(...)
+                mnist_loader = torch.utils.data.DataLoader(
+                    dataset=mnist, batch_size=self.batch_size, shuffle=True
+                )
+                cifar_loader = torch.utils.data.DataLoader(
+                    dataset=cifar, batch_size=self.batch_size, shuffle=True
+                )
                 # each batch will be a list of tensors: [batch_mnist, batch_cifar]
-                return [torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True),
-                        torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)]
+                return [mnist_loader, cifar_loader]
 
             # multiple dataloader, return as dict
             def train_dataloader(self):
                 mnist = MNIST(...)
                 cifar = CIFAR(...)
+                mnist_loader = torch.utils.data.DataLoader(
+                    dataset=mnist, batch_size=self.batch_size, shuffle=True
+                )
+                cifar_loader = torch.utils.data.DataLoader(
+                    dataset=cifar, batch_size=self.batch_size, shuffle=True
+                )
                 # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
-                return {'mnist': torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True),
-                        'cifar': torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)}
+                return {'mnist': mnist_loader, 'cifar': cifar_loader}
 
         """

From 03a2ebb0d6e3343ac21b7f71aa53013ea9308d6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 2 Mar 2021 02:36:24 +0100
Subject: [PATCH 8/8] Update pytorch_lightning/core/hooks.py

---
 pytorch_lightning/core/hooks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 0c30360cff785..604803365298c 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -388,7 +388,7 @@ def train_dataloader(self) -> Any:
         Implement one or more PyTorch DataLoaders for training.
 
         Return:
-            Either a single PyTorch :class:`~torch.utils.data.DataLoader` or an collection of these
+            Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
             (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
             this :ref:`page <multiple-training-dataloaders>`
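
Note: a self-contained toy example exercising the dict form end to end, written as a
sketch under the assumption that the feature behaves as the docstrings above describe.
``ToyModel`` and the random datasets are illustrative, not from the patch::

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class ToyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def train_dataloader(self):
            # two toy datasets; each batch arrives as {'a': batch_a, 'b': batch_b}
            loader_a = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=4)
            loader_b = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=4)
            return {'a': loader_a, 'b': loader_b}

        def training_step(self, batch, batch_idx):
            # unpack the single-tensor TensorDataset batches from each loader
            (x_a,), (x_b,) = batch['a'], batch['b']
            return self.layer(x_a).sum() + self.layer(x_b).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    pl.Trainer(max_epochs=1).fit(ToyModel())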