From 98aafa1eacb0d41964feb31ba0367cae9e205167 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 May 2020 16:25:55 +0200 Subject: [PATCH 01/37] add support for wrong dtype in apply_func --- pytorch_lightning/utilities/apply_func.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 775c22dbbfa0a..7225ea98d5b94 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -27,7 +27,8 @@ Batch = type(None) -def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: +def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, + wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any: """ Recursively applies a function to all elements of a certain dtype. @@ -36,6 +37,8 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable dtype: the given function will be applied to all elements of this dtype function: the function to apply *args: positional arguments (will be forwarded to calls of ``function``) + wrong_dtype: the given function won't be applied if this type is specified and the given collections is of + the :attr:`wrong_type` even if it is of type :attr`dtype` **kwargs: keyword arguments (will be forwarded to calls of ``function``) Returns: @@ -45,7 +48,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable elem_type = type(data) # Breaking condition - if isinstance(data, dtype): + if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): return function(data, *args, **kwargs) # Recursively apply to collection items From f7e2405c12ea4841cac0fa567677693f35200c2b Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 May 2020 16:26:30 +0200 Subject: [PATCH 02/37] apply loader resetting to possible collection of loaders --- pytorch_lightning/trainer/data_loading.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index f2bcb1d1760d6..0cc9198dd1dde 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -150,7 +150,8 @@ def reset_train_dataloader(self, model: LightningModule) -> None: self.num_training_batches = 0 # automatically add samplers - self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True) + self.train_dataloader = apply_to_collection( + self.train_dataloader, DataLoader, self.auto_add_sampler, train=True) self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') self._worker_check(self.train_dataloader, 'train dataloader') From dc1f0d0986a0eec0e6ff4c3ff152b1273dbb0b9f Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 May 2020 16:26:44 +0200 Subject: [PATCH 03/37] add combined loader iter class --- pytorch_lightning/trainer/supporters.py | 52 +++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 445ddbd87686c..95e2bc4d65f7f 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -18,6 +18,9 @@ import torch from pytorch_lightning.utilities.cloud_io import get_filesystem from torch import Tensor +from pytorch_lightning.utilities.apply_func import 
apply_to_collection +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import Any class TensorRunningAccum(object): @@ -176,3 +179,52 @@ def to_disk(self) -> None: # Write predictions for current file to disk with fs.open(filepath, "wb") as fp: torch.save(outputs, fp) + return getattr(self.memory[:self.current_idx], how)() + + +class CombinedLoaderIterator(object): + def __init__(self, loaders: Any): + self.loaders = loaders + self._loader_iters = None + + @property + def loader_iters(self): + if self._loader_iters is None: + self._loader_iters = self.create_loader_iters(self.loaders) + + return self._loader_iters + + def __iter__(self): + return self + + def __next__(self): + return self.request_next_batch(self.loader_iters) + + @staticmethod + def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]): + return apply_to_collection(loader_iters, Iterator, next) + + @staticmethod + def _calc_num_batches(loader_iters): + all_lengths = apply_to_collection(loader_iters, Iterator, len) + + if isinstance(all_lengths, int): + return all_lengths + + elif isinstance(all_lengths, Mapping): + return min(all_lengths.values()) + + elif isinstance(all_lengths, Sequence): + return min(all_lengths) + + raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, int or Mapping') + + @staticmethod + def create_loader_iters(loaders: Union[Any, Iterator, + Sequence, Mapping]) -> Union[Any, Iterator, Sequence, Mapping]: + + # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences + return apply_to_collection(loaders, Iterable, iter, wrong_dtype=Sequence) + + def __len__(self): + return self._calc_num_batches(self.loader_iters) From 4ab644f87e2f577184381b2788c35b67ba26b066 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 May 2020 16:27:07 +0200 Subject: [PATCH 04/37] integrate combined loader iter to training loop --- pytorch_lightning/trainer/training_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1637e3504dd0d..b88851460cd6f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -541,15 +541,15 @@ def run_training_epoch(self): model = self.trainer.get_model() # modify dataloader if needed (ddp, etc...) 
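To make the recursion pattern concrete, a small sketch of how apply_to_collection with the new wrong_dtype argument treats DataLoaders as leaves while recursing into containers (the loaders below are illustrative, not taken from the diff; the later patches pass wrong_dtype=(Sequence, Mapping) for exactly this reason):

# Sketch only, not part of the diff; relies on the wrong_dtype argument added in PATCH 01.
from collections.abc import Iterable, Iterator, Mapping, Sequence
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection

loaders = {'a': DataLoader(range(10), batch_size=4),
           'b': DataLoader(range(20), batch_size=5)}

# DataLoaders are Iterable, but so are dicts and lists. Excluding Sequence/Mapping via
# wrong_dtype makes apply_to_collection recurse into the containers and call iter()
# only on the loaders themselves.
loader_iters = apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))

# the same pattern fetches one batch from every iterator in the collection
batches = apply_to_collection(loader_iters, Iterator, next)
assert set(batches) == {'a', 'b'}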
- train_dataloader = self.trainer.accelerator_backend.process_dataloader(self.trainer.train_dataloader) + train_dataloader = self.trainer.accelerator_backend.process_dataloader(CombinedLoaderIterator(self.trainer.train_dataloader)) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] - # enable profiling for the dataloader - train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) + self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 should_check_val = False + for batch_idx, (batch, is_last_batch) in train_dataloader: self.trainer.batch_idx = batch_idx From 6b8f3176c8d2eb8677803ea97576ffc26833f0b0 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 1 Jun 2020 13:41:59 +0200 Subject: [PATCH 05/37] fix imports --- pytorch_lightning/utilities/apply_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 7225ea98d5b94..db81debc17db2 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -16,7 +16,7 @@ from abc import ABC from collections.abc import Mapping, Sequence from copy import copy -from typing import Any, Callable, Union +from typing import Any, Callable, Union, Optional import torch From 04ce0ef654ef818c7fec363cfd659a5a93a7cd2c Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 1 Jun 2020 14:14:08 +0200 Subject: [PATCH 06/37] fix imports --- pytorch_lightning/trainer/supporters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 95e2bc4d65f7f..9fc1070417fc3 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -20,7 +20,7 @@ from torch import Tensor from pytorch_lightning.utilities.apply_func import apply_to_collection from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import Any +from typing import Any, Union class TensorRunningAccum(object): From aa9b153acc9b4a47f7b926baeea75429cad7b4a6 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 29 Jun 2020 15:32:12 +0200 Subject: [PATCH 07/37] finish supporters --- pytorch_lightning/trainer/supporters.py | 129 ++++++++++++++++++++++-- 1 file changed, 120 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 9fc1070417fc3..1a8d19b26de51 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -182,31 +182,142 @@ def to_disk(self) -> None: return getattr(self.memory[:self.current_idx], how)() +class CycleIterator(object): + """ + Iterator for restarting a dataloader if it runs out of samples + """ + def __init__(self, loader, length: int = None): + """ + + Args: + loader: the loader to restart for cyclic (and optionally infinite) sampling + length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration + if None: infinite + """ + if length is None: + length = float('inf') + + self.length = length + self.loader = loader + self._loader_iter = None + self.counter = 0 + + def __iter__(self): + """ + Creates the internal iterator and returns self + Return: + CycleIterator: self + + """ + self._loader_iter = iter(self.loader) + return self + + def __next__(self) -> Any: + """ + Fetches the next batch from internal dataloader and 
restarts + it if necessary + + Return: + Any: the resulting batch + + Raises: + StopIteration: if more then :attr:`length` batches have been returned + + """ + if self.counter >= len(self): + raise StopIteration + + try: + return next(self._loader_iter) + + except StopIteration: + self._loader_iter = iter(self.loader) + return next(self._loader_iter) + finally: + self.counter += 1 + + def __len__(self) -> int: + return self.length + + class CombinedLoaderIterator(object): - def __init__(self, loaders: Any): + """ + Combines different dataloaders and allows sampling in parallel + """ + SUPPORTED_MODES = ('min_size', 'max_size_cycle') + + def __init__(self, loaders: Any, mode='min_size'): + """ + + Args: + loaders: the loaders to sample from. Can be all kind of collection + mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and + 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones. + """ self.loaders = loaders self._loader_iters = None + if mode not in self.SUPPORTED_MODES: + raise ValueError(f"Invalid Mode: {mode}") + + self.mode = mode + + if self.mode == 'max_size_cycle': + self._wrap_loaders_max_size_cycle() + + def _wrap_loaders_max_size_cycle(self) -> Any: + """ + Wraps all loaders to make sure they are cycled until the longest loader is exhausted + + Return: + Any: the wrapped loaders + """ + all_lengths = apply_to_collection(self.loaders, Iterable, len, + wrong_dtype=(Sequence, Mapping)) + if isinstance(all_lengths, int): + length = all_lengths + + elif isinstance(all_lengths, Mapping): + length = max(all_lengths.values()) + + elif isinstance(all_lengths, Sequence): + length = max(all_lengths) + + if isinstance(self.loaders, Mapping): + self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) + for k, v in self.loaders.items()}) + + elif isinstance(self.loaders, Sequence): + self.loaders = type(self.loaders)([CycleIterator(v, length=length) + for v in self.loaders]) + + # dataloaders are iterable but not sequence + elif isinstance(Iterable): + self.loaders = CycleIterator(self.loaders, length=length) + else: + raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}') + @property - def loader_iters(self): + def loader_iters(self) -> Any: if self._loader_iters is None: self._loader_iters = self.create_loader_iters(self.loaders) return self._loader_iters - def __iter__(self): + def __iter__(self) -> Any: return self - def __next__(self): + def __next__(self) -> Any: return self.request_next_batch(self.loader_iters) @staticmethod - def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]): + def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: return apply_to_collection(loader_iters, Iterator, next) @staticmethod - def _calc_num_batches(loader_iters): - all_lengths = apply_to_collection(loader_iters, Iterator, len) + def _calc_num_batches(loaders) -> int: + all_lengths = apply_to_collection(loaders, Iterable, len, + wrong_dtype=(Sequence, Mapping)) if isinstance(all_lengths, int): return all_lengths @@ -224,7 +335,7 @@ def create_loader_iters(loaders: Union[Any, Iterator, Sequence, Mapping]) -> Union[Any, Iterator, Sequence, Mapping]: # dataloaders are Iterable but not Sequences. 
Need this to specifically exclude sequences - return apply_to_collection(loaders, Iterable, iter, wrong_dtype=Sequence) + return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping)) def __len__(self): - return self._calc_num_batches(self.loader_iters) + return self._calc_num_batches(self.loaders) From 9525024948bf60f2e9df040b55810d9d7340b4c4 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 29 Jun 2020 15:32:31 +0200 Subject: [PATCH 08/37] add tests for supporters --- tests/trainer/test_supporters.py | 76 ++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/trainer/test_supporters.py diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py new file mode 100644 index 0000000000000..2eebd5d2b0022 --- /dev/null +++ b/tests/trainer/test_supporters.py @@ -0,0 +1,76 @@ +from collections import Sequence + +import torch + +from pytorch_lightning.trainer.supporters import CycleIterator, CombinedLoaderIterator + + +def test_cycle_iterator(): + iterator = CycleIterator(range(100), 1000) + assert len(iterator) == 1000 + for idx, item in enumerate(iterator): + assert item < 100 + + assert idx == len(iterator) - 1 + + +def test_combined_loader_iterator_dict_min_size(): + loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), + 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} + + combined_loader = CombinedLoaderIterator(loaders, 'min_size') + + assert len(combined_loader) == min([len(v) for v in loaders.values()]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, dict) + assert len(item) == 2 + assert 'a' in item and 'b' in item + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_iterator_dict_max_size_cycle(): + loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), + 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} + + combined_loader = CombinedLoaderIterator(loaders, 'max_size_cycle') + + assert len(combined_loader) == max([len(v) for v in loaders.values()]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, dict) + assert len(item) == 2 + assert 'a' in item and 'b' in item + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_iterator_sequence_min_size(): + loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), + torch.utils.data.DataLoader(range(20), batch_size=5)] + + combined_loader = CombinedLoaderIterator(loaders, 'min_size') + + assert len(combined_loader) == min([len(v) for v in loaders]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_iterator_sequence_max_size_cycle(): + loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), + torch.utils.data.DataLoader(range(20), batch_size=5)] + + combined_loader = CombinedLoaderIterator(loaders, 'max_size_cycle') + + assert len(combined_loader) == max([len(v) for v in loaders]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 + + assert idx == len(combined_loader) - 1 From dab57afa92504d7cf2d0f7ea635d1347c0fdd560 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 29 Jun 2020 15:32:58 +0200 Subject: [PATCH 09/37] add test for model with multiple loaders --- tests/trainer/test_dataloaders.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git 
a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 50c426c174349..318ed7bebea15 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -33,7 +33,6 @@ def test_fit_train_loader_only(tmpdir): - model = EvalModelTemplate() train_dataloader = model.train_dataloader() @@ -52,7 +51,6 @@ def test_fit_train_loader_only(tmpdir): def test_fit_val_loader_only(tmpdir): - model = EvalModelTemplate() train_dataloader = model.train_dataloader() val_dataloader = model.val_dataloader() @@ -857,6 +855,29 @@ def train_dataloader(self): assert 1 == result +def test_fit_multiple_train_loaders(tmpdir): + class MutipleLoaderModel(EvalModelTemplate): + def train_dataloader(self): + return {'a': super().train_dataloader(), + 'b': super().train_dataloader()} + + def training_step(self, batch, batch_idx, optimizer_idx=None): + assert isinstance(batch, dict) + assert len(batch) == 2 + assert 'a' in batch and 'b' in batch + return super().training_step(batch=batch['a'], batch_idx=batch_idx, + optimizer_idx=optimizer_idx) + + hparams = EvalModelTemplate.get_default_hparams() + + model = MutipleLoaderModel(**hparams) + + trainer = Trainer( + fast_dev_run=True, default_root_dir=tmpdir + ) + assert 1 == trainer.fit(model) + + @pytest.mark.parametrize('check_interval', [1.0]) def test_val_dataloader_not_implemented_error(tmpdir, check_interval): """Test not_implemented_error data loader (e.g. IterableDataset)""" From 2f867dc7ba344fad094790dc54f5a96fa7efdfa5 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Mon, 29 Jun 2020 15:33:44 +0200 Subject: [PATCH 10/37] fix trainer integration --- pytorch_lightning/trainer/trainer.py | 3 ++- pytorch_lightning/trainer/training_loop.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 35da90625adef..c4540c41a85bb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -134,6 +134,7 @@ def __init__( automatic_optimization: Optional[bool] = None, move_metrics_to_cpu: bool = False, enable_pl_optimizer: bool = True, + multiple_trainloader_mode: str = 'max_size_cycle', ): r""" Customize every aspect of training via flags @@ -305,7 +306,7 @@ def __init__( self.tuner = Tuner(self) self.accelerator_backend = None self.evaluation_loop = EvaluationLoop(self) - self.train_loop = TrainLoop(self) + self.train_loop = TrainLoop(self, multiple_trainloader_mode) self.plugin_connector = PluginConnector(self) # training state diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b88851460cd6f..26ebb6a4dcfa0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -34,7 +34,7 @@ class TrainLoop: - def __init__(self, trainer): + def __init__(self, trainer, multiple_trainloader_mode): self.trainer = trainer self.early_stopping_accumulator = None self.checkpoint_accumulator = None @@ -45,6 +45,7 @@ def __init__(self, trainer): self.automatic_optimization = True self._curr_step_result = None self._cur_grad_norm_dict = None + self._multiple_trainloader_mode def on_trainer_init( self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization @@ -541,7 +542,7 @@ def run_training_epoch(self): model = self.trainer.get_model() # modify dataloader if needed (ddp, etc...) 
- train_dataloader = self.trainer.accelerator_backend.process_dataloader(CombinedLoaderIterator(self.trainer.train_dataloader)) + train_dataloader = self.trainer.accelerator_backend.process_dataloader(CombinedLoaderIterator(self.trainer.train_dataloader, mode=self._multiple_trainloader_mode)) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] From 0ae527bb04cc4037b25ee665f825f7d1ca8bcec7 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 29 Jun 2020 16:15:06 +0200 Subject: [PATCH 11/37] fix instance check --- pytorch_lightning/trainer/supporters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 1a8d19b26de51..ca6b1b07cfcd1 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -292,7 +292,7 @@ def _wrap_loaders_max_size_cycle(self) -> Any: for v in self.loaders]) # dataloaders are iterable but not sequence - elif isinstance(Iterable): + elif isinstance(self.loaders, Iterable): self.loaders = CycleIterator(self.loaders, length=length) else: raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}') From 608f503b553359bb59d109c293956f053473e5f5 Mon Sep 17 00:00:00 2001 From: Christofer Fransson Date: Sat, 10 Oct 2020 10:25:11 +0200 Subject: [PATCH 12/37] Train loaders (#4032) * patch for issues discussed in #1959, encapsulating underlying datastructures returned from train_dataloader * update data_loading.py to it uses patch discussed in #1959 --- pytorch_lightning/trainer/data_loading.py | 12 ++++- .../trainer/train_loader_patch.py | 48 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 pytorch_lightning/trainer/train_loader_patch.py diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 0cc9198dd1dde..5b86a2d7bcce6 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -29,6 +29,14 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils +from copy import deepcopy +from typing import Iterable +from pytorch_lightning.utilities.apply_func import apply_to_collection + +TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() +from pytorch_lightning.trainer.train_loader_patch import MagicClass + class TrainerDataLoadingMixin(ABC): @@ -308,9 +316,9 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: Returns: The dataloader """ - dataloader = dataloader_fx() + dataloader = MagicClass(dataloader_fx()) dataloader = self._flatten_dl_only(dataloader) - + if self.accelerator_backend is not None: self.accelerator_backend.barrier('get_dataloaders') return dataloader diff --git a/pytorch_lightning/trainer/train_loader_patch.py b/pytorch_lightning/trainer/train_loader_patch.py new file mode 100644 index 0000000000000..7038a2655c729 --- /dev/null +++ b/pytorch_lightning/trainer/train_loader_patch.py @@ -0,0 +1,48 @@ +''' + This patch solves two problems discussed in + https://github.com/PyTorchLightning/pytorch-lightning/pull/1959 + + The function train_dataloader can either return a single instance of + torch.utils.data.DataLoader or a dictionary of dataloaders. 
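For orientation, a hedged sketch of the user-facing contract described here (the module, dataset sizes and key names are illustrative, not taken from the patch): train_dataloader returns a mapping of loaders, and training_step then receives one batch per loader under the same keys.

# Illustrative only: a minimal LightningModule with two train loaders of different length.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class TwoLoaderModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def train_dataloader(self):
        # 'a' yields 10 batches, 'b' only 5
        return {'a': DataLoader(TensorDataset(torch.randn(100, 32), torch.randint(0, 2, (100,))), batch_size=10),
                'b': DataLoader(TensorDataset(torch.randn(50, 32), torch.randint(0, 2, (50,))), batch_size=10)}

    def training_step(self, batch, batch_idx):
        # batch arrives as {'a': [x, y], 'b': [x, y]}, one entry per loader
        x, y = batch['a']
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)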
+ + This patch fixes the length and iteration issus + and make the rest of the code oblivious of the underlying data structure. + + I will keep the name of the class but a better name is probable advisable + + @christofer-f +''' + +import itertools + +def get_len(d): + if isinstance(d, dict): + v = max(d.items(), key=lambda x: len(x[1])) + return len(v[1]) + else: + return len(d) + +class MagicClass(object): + def __init__(self, data) -> None: + super(object, self).__init__() + self.d = data + self.l = get_len(data) + + def __len__(self) -> int: + return get_len(self.d) + + def __iter__(self): + if isinstance(self.d, dict): + gen = {} + for k,v in self.d.items(): + gen[k] = itertools.cycle(v) + for i in range(self.l): + rv = {} + for k,v in self.d.items(): + rv[k] = next(gen[k]) + yield rv + else: + gen = itertools.cycle(self.d) + for i in range(self.l): + batch = next(gen) + yield batch \ No newline at end of file From a1ced87c956db1db39160c4c30faab2569f19341 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Wed, 28 Oct 2020 10:35:50 +0100 Subject: [PATCH 13/37] rename class --- pytorch_lightning/trainer/data_loading.py | 3 ++- pytorch_lightning/trainer/train_loader_patch.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 5b86a2d7bcce6..3877261e68676 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -33,6 +33,7 @@ from copy import deepcopy from typing import Iterable from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.trainer.train_loader_patch import MultiIterator TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() from pytorch_lightning.trainer.train_loader_patch import MagicClass @@ -316,7 +317,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: Returns: The dataloader """ - dataloader = MagicClass(dataloader_fx()) + dataloader = MultiIterator(dataloader_fx()) dataloader = self._flatten_dl_only(dataloader) if self.accelerator_backend is not None: diff --git a/pytorch_lightning/trainer/train_loader_patch.py b/pytorch_lightning/trainer/train_loader_patch.py index 7038a2655c729..7ec96b04c6761 100644 --- a/pytorch_lightning/trainer/train_loader_patch.py +++ b/pytorch_lightning/trainer/train_loader_patch.py @@ -22,7 +22,7 @@ def get_len(d): else: return len(d) -class MagicClass(object): +class MultiIterator(object): def __init__(self, data) -> None: super(object, self).__init__() self.d = data From eea6aae4a2edcea64cf7626989f8e9f08017b3ec Mon Sep 17 00:00:00 2001 From: YI-LIN SUNG Date: Mon, 30 Nov 2020 20:45:48 +0800 Subject: [PATCH 14/37] Separate CombinedLoaderIterator into two classes, and update related tests. (#4606) * Fix the bugs after rebasing. * Add custom get_len for apply_to_collection * Refactor MultiIterator to be as CombinedLoaderIterator * To get the right num_training_batches. Call the wrapper for multi trainloader in data_loading.py, instead of training_loop.py * Reload _loader_iters when calling __iter__ * Don't transform DataLoader to CombinedLoaderIterator when it's along * Updates test_fit_multiple_train_loaders for testing num_training_batches * Seperate CombinedLoaderIterator into CombinedLoaderIterator and CombinedDataLoader. Add CombinedDataset for unified DataLoader format. * Initialize CombinedDataLoader before calculating num_training_batches. 
Also updating self._worker_check for multiple loaders * Update tests for supporters * Update tests for multiple trainloaders. Add tests about few_workers for multiple loaders. * Fix pep8 issues * Add tests for train_loader_patch.py * Add descriptions to multiple_trainloader_mode * Remove unused variables * Add docstrings and typing * Add more tests for better converage * Remove unused commented codes * Add sampler property * Remove extract_dataset * Update typing --- pytorch_lightning/trainer/data_loading.py | 19 +- pytorch_lightning/trainer/supporters.py | 216 +++++++++++++++--- .../trainer/train_loader_patch.py | 92 +++++--- pytorch_lightning/trainer/trainer.py | 5 + pytorch_lightning/trainer/training_loop.py | 7 +- pytorch_lightning/utilities/data.py | 11 + tests/base/model_train_dataloaders.py | 7 +- tests/base/model_train_steps.py | 30 +++ tests/trainer/test_dataloaders.py | 85 +++++-- tests/trainer/test_supporters.py | 88 ++++++- tests/trainer/test_train_loader_patch.py | 69 ++++++ 11 files changed, 535 insertions(+), 94 deletions(-) create mode 100644 tests/trainer/test_train_loader_patch.py diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 3877261e68676..21768808215bb 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -34,6 +34,7 @@ from typing import Iterable from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.trainer.train_loader_patch import MultiIterator +from pytorch_lightning.trainer.supporters import CombinedLoaderIterator, CombinedLoader TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() from pytorch_lightning.trainer.train_loader_patch import MagicClass @@ -146,6 +147,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: model: The current `LightningModule` """ self.train_dataloader = self.request_dataloader(model.train_dataloader) + if (self.overfit_batches > 0): if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.' 
@@ -156,14 +158,19 @@ def reset_train_dataloader(self, model: LightningModule) -> None: # debugging self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader]) - self.num_training_batches = 0 - # automatically add samplers self.train_dataloader = apply_to_collection( - self.train_dataloader, DataLoader, self.auto_add_sampler, train=True) + self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True) + + # check the workers recursively + apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader') + + # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches + self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode) + + self.num_training_batches = 0 self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') - self._worker_check(self.train_dataloader, 'train dataloader') if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) @@ -317,9 +324,9 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: Returns: The dataloader """ - dataloader = MultiIterator(dataloader_fx()) + dataloader = dataloader_fx() dataloader = self._flatten_dl_only(dataloader) - + if self.accelerator_backend is not None: self.accelerator_backend.barrier('get_dataloaders') return dataloader diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index ca6b1b07cfcd1..b48ec65952692 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -18,7 +18,9 @@ import torch from pytorch_lightning.utilities.cloud_io import get_filesystem from torch import Tensor +from torch.utils.data import Dataset from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.data import get_len from collections.abc import Iterable, Iterator, Mapping, Sequence from typing import Any, Union @@ -179,20 +181,20 @@ def to_disk(self) -> None: # Write predictions for current file to disk with fs.open(filepath, "wb") as fp: torch.save(outputs, fp) - return getattr(self.memory[:self.current_idx], how)() class CycleIterator(object): """ Iterator for restarting a dataloader if it runs out of samples """ - def __init__(self, loader, length: int = None): + def __init__(self, loader: Any, length: int = None): """ Args: loader: the loader to restart for cyclic (and optionally infinite) sampling length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration if None: infinite + """ if length is None: length = float('inf') @@ -202,13 +204,16 @@ def __init__(self, loader, length: int = None): self._loader_iter = None self.counter = 0 - def __iter__(self): + def __iter__(self) -> Any: """ + Creates the internal iterator and returns self - Return: + + Returns: CycleIterator: self """ + self.counter = 0 self._loader_iter = iter(self.loader) return self @@ -217,14 +222,15 @@ def __next__(self) -> Any: Fetches the next batch from internal dataloader and restarts it if necessary - Return: + Returns: Any: the resulting batch Raises: StopIteration: if more then :attr:`length` batches have been returned """ - if self.counter >= len(self): + # Note: if self.length is `inf`, then the iterator will never stop + if self.counter >= self.__len__(): raise StopIteration try: 
@@ -233,6 +239,7 @@ def __next__(self) -> Any: except StopIteration: self._loader_iter = iter(self.loader) return next(self._loader_iter) + finally: self.counter += 1 @@ -240,22 +247,100 @@ def __len__(self) -> int: return self.length -class CombinedLoaderIterator(object): +class CombinedDataset(object): + """ + Combine multiple datasets and compute their statistics + """ + def __init__(self, datasets: Union[Sequence, Mapping]): + """ + + Args: + datasets: a sequence/mapping datasets. Can be a collections of torch.utils.Dataset, + Iterable or even None. + + """ + self.datasets = datasets + + self.max_len = self._calc_num_data(self.datasets, 'max') + self.min_len = self._calc_num_data(self.datasets, 'min') + + @staticmethod + def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str = 'min') -> Union[int, float]: + """ + Compute the length of `CombinedDataset` according to the `mode`. + + Args: + datasets: a sequence/mapping datasets. Can be a collections of torch.utils.data.Dataset, + Iterable or even None. + mode (str): 'min' or 'max'. Determine `CombinedDataset`'s length is the maximum or minimum of + the datasets. + + Returns: + length (int): the length of `CombinedDataset` + + """ + if mode not in ['min', 'max']: + raise ValueError(f"Invalid Mode: {mode}") + + # extract the lengths + all_lengths = apply_to_collection(datasets, (Dataset, Iterable, type(None)), get_len, + wrong_dtype=(Sequence, Mapping)) + + compute_func = eval(mode) + + if isinstance(all_lengths, (int, float)): + length = all_lengths + + elif isinstance(all_lengths, Mapping): + length = compute_func(all_lengths.values()) + + elif isinstance(all_lengths, Sequence): + length = compute_func(all_lengths) + + return length + + def __len__(self) -> int: + """Return the minimum length of the datasets.""" + return self.min_len + + +class CombinedLoader(object): """ - Combines different dataloaders and allows sampling in parallel + Combines different dataloaders and allows sampling in parallel. + + Examples: + >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), \ + 'b': torch.utils.data.DataLoader(range(15), batch_size=5)} + >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle') + >>> for item in combined_loader: \ + print(item) + {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} + {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} + {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])} + >>> combined_loader = CombinedLoader(loaders, 'min_size') + >>> for item in combined_loader: \ + print(item) + {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} + {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} + """ SUPPORTED_MODES = ('min_size', 'max_size_cycle') - def __init__(self, loaders: Any, mode='min_size'): + def __init__(self, loaders: Any, mode: str = 'min_size'): """ Args: loaders: the loaders to sample from. Can be all kind of collection mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones. 
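A short sketch of how the two modes differ for loaders of unequal length, and of CycleIterator on its own (loader sizes are made up; the class names are the ones added in this patch):

# Sketch only; assumes the classes above are importable from pytorch_lightning.trainer.supporters.
from torch.utils.data import DataLoader
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

short_loader = DataLoader(range(6), batch_size=2)    # 3 batches
long_loader = DataLoader(range(20), batch_size=4)    # 5 batches

# 'min_size' stops once the shortest loader is exhausted
assert len(CombinedLoader({'s': short_loader, 'l': long_loader}, 'min_size')) == 3

# 'max_size_cycle' wraps the shorter loaders in CycleIterator so they restart
# until the longest loader is exhausted
assert len(CombinedLoader({'s': short_loader, 'l': long_loader}, 'max_size_cycle')) == 5

# CycleIterator alone keeps re-iterating its loader until `length` batches were produced
assert len(list(CycleIterator(short_loader, length=5))) == 5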
+ """ self.loaders = loaders - self._loader_iters = None + + datasets = apply_to_collection(self.loaders, Iterable, getattr, 'dataset', None, + wrong_dtype=(Sequence, Mapping)) + # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader + self.dataset = CombinedDataset(datasets) if mode not in self.SUPPORTED_MODES: raise ValueError(f"Invalid Mode: {mode}") @@ -265,16 +350,24 @@ def __init__(self, loaders: Any, mode='min_size'): if self.mode == 'max_size_cycle': self._wrap_loaders_max_size_cycle() + @property + def sampler(self) -> Union[Iterable, Sequence, Mapping]: + """Return a collections of samplers extracting from loaders.""" + return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, + wrong_dtype=(Sequence, Mapping)) + def _wrap_loaders_max_size_cycle(self) -> Any: """ Wraps all loaders to make sure they are cycled until the longest loader is exhausted - Return: + Returns: Any: the wrapped loaders + """ - all_lengths = apply_to_collection(self.loaders, Iterable, len, + all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) - if isinstance(all_lengths, int): + + if isinstance(all_lengths, (int, float)): length = all_lengths elif isinstance(all_lengths, Mapping): @@ -293,12 +386,66 @@ def _wrap_loaders_max_size_cycle(self) -> Any: # dataloaders are iterable but not sequence elif isinstance(self.loaders, Iterable): - self.loaders = CycleIterator(self.loaders, length=length) + # only one dataloader, just keep it the same. + pass else: raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}') + def __iter__(self) -> Any: + """ + Create and return an iterator, `CombinedLoaderIterator`, for the combined loader. + """ + return CombinedLoaderIterator(self.loaders) + + @staticmethod + def _calc_num_batches(loaders: Any) -> Union[int, float]: + """ + Compute the length (aka the number of batches) of `CombinedLoader`. + + Args: + loaders: a collections of loaders. + + Returns: + length (int): the minimum length of loaders + + """ + all_lengths = apply_to_collection(loaders, Iterable, get_len, + wrong_dtype=(Sequence, Mapping)) + + if isinstance(all_lengths, (int, float)): + return all_lengths + + elif isinstance(all_lengths, Mapping): + return min(all_lengths.values()) + + elif isinstance(all_lengths, Sequence): + return min(all_lengths) + + raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, int or Mapping') + + def __len__(self) -> int: + return self._calc_num_batches(self.loaders) + + +class CombinedLoaderIterator(object): + """ + Custom Iterator returning data from multple loaders, and allows sampling in parallel + """ + def __init__(self, loaders: Any): + """ + + Args: + loaders: the loaders to sample from. Can be all kind of collection + + """ + self.loaders = loaders + self._loader_iters = None + @property def loader_iters(self) -> Any: + """ + Get the `_loader_iters` and create one if it is None. 
+ """ if self._loader_iters is None: self._loader_iters = self.create_loader_iters(self.loaders) @@ -308,34 +455,41 @@ def __iter__(self) -> Any: return self def __next__(self) -> Any: + """ + Fetches the next batch from multiple data loaders + + Returns: + Any: a collections of batch data + + """ return self.request_next_batch(self.loader_iters) @staticmethod def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: - return apply_to_collection(loader_iters, Iterator, next) - - @staticmethod - def _calc_num_batches(loaders) -> int: - all_lengths = apply_to_collection(loaders, Iterable, len, - wrong_dtype=(Sequence, Mapping)) - - if isinstance(all_lengths, int): - return all_lengths + """ + Return the batch of data from multiple iterators. - elif isinstance(all_lengths, Mapping): - return min(all_lengths.values()) + Args: + loader_iters: a collections of iterators - elif isinstance(all_lengths, Sequence): - return min(all_lengths) + Returns + Any: a collections of batch data - raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, int or Mapping') + """ + return apply_to_collection(loader_iters, Iterator, next) @staticmethod def create_loader_iters(loaders: Union[Any, Iterator, Sequence, Mapping]) -> Union[Any, Iterator, Sequence, Mapping]: + """ + Create and return a collection of iterators from loaders. + + Args: + loaderss: a collections of loaders + + Returns + a collections of iterators + """ # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping)) - - def __len__(self): - return self._calc_num_batches(self.loaders) diff --git a/pytorch_lightning/trainer/train_loader_patch.py b/pytorch_lightning/trainer/train_loader_patch.py index 7ec96b04c6761..60de3aaa4d9ff 100644 --- a/pytorch_lightning/trainer/train_loader_patch.py +++ b/pytorch_lightning/trainer/train_loader_patch.py @@ -1,11 +1,11 @@ ''' - This patch solves two problems discussed in + This patch solves two problems discussed in https://github.com/PyTorchLightning/pytorch-lightning/pull/1959 - The function train_dataloader can either return a single instance of + The function train_dataloader can either return a single instance of torch.utils.data.DataLoader or a dictionary of dataloaders. - This patch fixes the length and iteration issus + This patch fixes the length and iteration issus and make the rest of the code oblivious of the underlying data structure. 
I will keep the name of the class but a better name is probable advisable @@ -14,35 +14,69 @@ ''' import itertools - -def get_len(d): - if isinstance(d, dict): - v = max(d.items(), key=lambda x: len(x[1])) - return len(v[1]) - else: - return len(d) + +from typing import Any, Union +from collections.abc import Iterable, Iterator, Mapping, Sequence + +from torch.utils.data import DataLoader + +from pytorch_lightning.utilities.data import get_len +from pytorch_lightning.utilities.apply_func import apply_to_collection + class MultiIterator(object): - def __init__(self, data) -> None: - super(object, self).__init__() - self.d = data - self.l = get_len(data) - - def __len__(self) -> int: - return get_len(self.d) + SUPPORTED_MODES = ('min_size', 'max_size_cycle') + + def __init__(self, loaders: Any, mode: str = 'min_size') -> None: + self.loaders = loaders + self.num_batches = self._calc_num_batches(loaders, mode) + + def _calc_num_batches(self, loaders, mode: str) -> Union[int, float]: + all_lengths = apply_to_collection(loaders, Iterable, get_len, + wrong_dtype=(Sequence, Mapping)) + + if mode == 'min_size': + compare_func = min + elif mode == 'max_size_cycle': + compare_func = max + else: + raise ValueError(f"Invalid Mode: {mode}") + + if isinstance(all_lengths, (int, float)): + return all_lengths + if isinstance(all_lengths, Mapping): + return compare_func(all_lengths.values()) + elif isinstance(all_lengths, Sequence): + return compare_func(all_lengths) + + raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, Mapping, int or float') + + def __len__(self) -> Union[int, float]: + # Return type might be int or inf. Inf will cause type error when calling len() + return self.num_batches def __iter__(self): - if isinstance(self.d, dict): - gen = {} - for k,v in self.d.items(): - gen[k] = itertools.cycle(v) - for i in range(self.l): + if isinstance(self.loaders, Mapping): + gens = {} + for batch_idx in range(self.num_batches): rv = {} - for k,v in self.d.items(): - rv[k] = next(gen[k]) + for loader_name, loader in self.loaders.items(): + # If reaching the end of the iterator, recreate one + # because shuffle=True in dataloader, the iterator will have a different order + if batch_idx % len(loader) == 0: + gens[loader_name] = iter(loader) + rv[loader_name] = next(gens[loader_name]) yield rv - else: - gen = itertools.cycle(self.d) - for i in range(self.l): - batch = next(gen) - yield batch \ No newline at end of file + elif isinstance(self.loaders, Sequence): + gens = [None] * self.num_batches + for batch_idx in range(self.num_batches): + rv = [] + for idx, loader in enumerate(self.loaders): + # If reaching the end of the iterator, recreate one + # because shuffle=True in dataloader, the iterator will have a different order + if batch_idx % len(loader) == 0: + gens[idx] = iter(loader) + rv.append(next(gens[idx])) + yield rv + + return iter(self.loaders) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c4540c41a85bb..4aecb54d27182 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -283,6 +283,11 @@ def __init__( enable_pl_optimizer: If True, each optimizer will be wrapped by `pytorch_lightning.core.optimizer.LightningOptimizer`. It allows Lightning to handle AMP, TPU, accumulated_gradients, etc.. + + multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders. 
+ In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed, + and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets + reload when reaching the minimum length of datasets. """ super().__init__() self._device_type = DeviceType.CPU diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 26ebb6a4dcfa0..831ace2b26ea0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -45,7 +45,8 @@ def __init__(self, trainer, multiple_trainloader_mode): self.automatic_optimization = True self._curr_step_result = None self._cur_grad_norm_dict = None - self._multiple_trainloader_mode + self._multiple_trainloader_mode = multiple_trainloader_mode + self.trainer._multiple_trainloader_mode = multiple_trainloader_mode def on_trainer_init( self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization @@ -542,12 +543,12 @@ def run_training_epoch(self): model = self.trainer.get_model() # modify dataloader if needed (ddp, etc...) - train_dataloader = self.trainer.accelerator_backend.process_dataloader(CombinedLoaderIterator(self.trainer.train_dataloader, mode=self._multiple_trainloader_mode)) + train_dataloader = self.trainer.accelerator_backend.process_dataloader(self.trainer.train_dataloader) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] - self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) + train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 should_check_val = False diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 54f81f20f9ab7..a997852e38e0f 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -17,6 +17,7 @@ from torch.utils.data import DataLoader, IterableDataset from pytorch_lightning.utilities import rank_zero_warn +from typing import Union def has_iterable_dataset(dataloader: DataLoader): @@ -45,3 +46,13 @@ def has_len(dataloader: DataLoader) -> bool: ' this can lead to unintended side effects since the samples will be duplicated.' ) return has_len + + +def get_len(dataloader: DataLoader) -> Union[int, float]: + """ Return the length of the given DataLoader. If __len__ method is not implemented, + return float('inf'). 
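A brief sketch of the helper's behaviour (the IterableDataset below is illustrative; only get_len itself comes from this patch):

# Sketch only: sized loaders report their batch count, un-sized ones report float('inf').
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning.utilities.data import get_len


class Stream(IterableDataset):
    def __iter__(self):
        return iter(range(10))


assert get_len(DataLoader(range(10), batch_size=5)) == 2
assert get_len(DataLoader(Stream(), batch_size=5)) == float('inf')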
""" + + if has_len(dataloader): + return len(dataloader) + + return float('inf') diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py index ad980f14fe95c..4744e5d52ec65 100644 --- a/tests/base/model_train_dataloaders.py +++ b/tests/base/model_train_dataloaders.py @@ -20,7 +20,7 @@ class TrainDataloaderVariations(ABC): @abstractmethod - def dataloader(self, train: bool): + def dataloader(self, train: bool, *args, **kwargs): """placeholder""" def train_dataloader(self): @@ -37,3 +37,8 @@ def train_dataloader__zero_length(self): dataloader.dataset.data = dataloader.dataset.data[:0] dataloader.dataset.targets = dataloader.dataset.targets[:0] return dataloader + + def train_dataloader__multiple(self): + """Return a mapping loaders with different lengths""" + return {'a': self.dataloader(train=True, num_samples=100), + 'b': self.dataloader(train=True, num_samples=50)} diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 95a4086274b35..ccc40e16892a9 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -149,3 +149,33 @@ def eval_epoch_end_full_loop_result_obj_dp(self, result): setattr(result, f'{eval_name}_step_metric', reduced) return result + + def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None): + """Training step for multiple train loaders""" + + assert isinstance(batch, dict) + assert len(batch) == 2 + assert 'a' in batch and 'b' in batch + + # forward pass + x, y = batch['a'] + x = x.view(x.size(0), -1) + y_hat = self(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + log_val = loss_val + + # alternate between tensors and scalars for "log" and "progress_bar" + if batch_idx % 2 == 0: + log_val = log_val.item() + + output = OrderedDict( + { + 'loss': loss_val, + 'progress_bar': {'some_val': log_val * log_val}, + 'log': {'train_some_val': log_val * log_val}, + } + ) + return output + diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 318ed7bebea15..599ae862b39f1 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -656,6 +656,62 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path): trainer.test(**test_options) +@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.') +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +@patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) +def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): + """ Test that error is raised if dataloader with only a few workers is used """ + + model = EvalModelTemplate() + model.training_step = model.training_step__multiple_dataloaders + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders + model.test_epoch_end = model.test_epoch_end__multiple_dataloaders + + # logger file to get meta + train_dl = model.dataloader(train=True) + train_dl.num_workers = 0 + + val_dl = model.dataloader(train=False) + val_dl.num_workers = 0 + + train_dl = model.dataloader(train=False) + train_dl.num_workers = 0 + + train_multi_dl = {'a': train_dl, 'b': train_dl} + val_multi_dl = [val_dl, val_dl] + test_multi_dl = [train_dl, train_dl] + + fit_options = dict(train_dataloader=train_multi_dl, + val_dataloaders=val_multi_dl) + trainer = Trainer( + 
default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + + # fit model + with pytest.warns( + UserWarning, match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.' + ): + trainer.fit(model, **fit_options) + + with pytest.warns( + UserWarning, match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.' + ): + trainer.fit(model, **fit_options) + + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path + test_options = dict(test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) + with pytest.warns( + UserWarning, match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.' + ): + trainer.test(**test_options) + + @pytest.mark.xfail( LooseVersion(torch.__version__) < LooseVersion("1.4.0"), reason="IterableDataset with __len__ before 1.4 raises", @@ -855,27 +911,24 @@ def train_dataloader(self): assert 1 == result -def test_fit_multiple_train_loaders(tmpdir): - class MutipleLoaderModel(EvalModelTemplate): - def train_dataloader(self): - return {'a': super().train_dataloader(), - 'b': super().train_dataloader()} - - def training_step(self, batch, batch_idx, optimizer_idx=None): - assert isinstance(batch, dict) - assert len(batch) == 2 - assert 'a' in batch and 'b' in batch - return super().training_step(batch=batch['a'], batch_idx=batch_idx, - optimizer_idx=optimizer_idx) - - hparams = EvalModelTemplate.get_default_hparams() +@pytest.mark.parametrize(['multiple_trainloader_mode', 'num_training_batches'], [ + pytest.param("min_size", 5), + pytest.param("max_size_cycle", 10), +]) +def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches): + """Integration test for multple train loaders""" + model = EvalModelTemplate() - model = MutipleLoaderModel(**hparams) + model.train_dataloader = model.train_dataloader__multiple + model.training_step = model.training_step__multiple_dataloaders trainer = Trainer( - fast_dev_run=True, default_root_dir=tmpdir + max_epochs=1, default_root_dir=tmpdir, multiple_trainloader_mode=multiple_trainloader_mode ) + assert 1 == trainer.fit(model) + # verify the num_training_batches according to the multiple_trainloader_mode + assert num_training_batches == trainer.num_training_batches @pytest.mark.parametrize('check_interval', [1.0]) diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 2eebd5d2b0022..070997dd48e6f 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -1,11 +1,14 @@ from collections import Sequence +import pytest import torch -from pytorch_lightning.trainer.supporters import CycleIterator, CombinedLoaderIterator +from torch.utils.data import TensorDataset +from pytorch_lightning.trainer.supporters import CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator def test_cycle_iterator(): + """Test the cycling function of `CycleIterator`""" iterator = CycleIterator(range(100), 1000) assert len(iterator) == 1000 for idx, item in enumerate(iterator): @@ -14,11 +17,77 @@ def test_cycle_iterator(): assert idx == len(iterator) - 1 +def test_none_length_cycle_iterator(): + """Test the infinite cycling function of `CycleIterator`""" + iterator = CycleIterator(range(100)) + assert iterator.__len__() == float('inf') + + # test infinite loop + for idx, item in enumerate(iterator): + if idx == 1000: + break + assert item == 0 + + +@pytest.mark.parametrize(['dataset_1', 
'dataset_2'], [ + ([list(range(10)), list(range(20))]), + ([range(10), range(20)]), + ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]), + ([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]) +]) +def test_combined_dataset(dataset_1, dataset_2): + """Verify the length of the CombinedDataset""" + datasets = [dataset_1, dataset_2] + combined_dataset = CombinedDataset(datasets) + + assert combined_dataset.max_len == 20 + assert combined_dataset.min_len == len(combined_dataset) == 10 + + +def test_combined_dataset_length_mode_error(): + with pytest.raises(ValueError): + CombinedDataset._calc_num_data([range(10)], 'test') + + def test_combined_loader_iterator_dict_min_size(): + """Test `CombinedLoaderIterator` given mapping loaders""" + loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), + 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} + + combined_iter = CombinedLoaderIterator(loaders) + + for idx, item in enumerate(combined_iter): + assert isinstance(item, dict) + assert len(item) == 2 + assert 'a' in item and 'b' in item + + assert idx == min(len(loaders['a']), len(loaders['b'])) - 1 + + +def test_combined_loader_init_mode_error(): + """Test the ValueError when constructing `CombinedLoader`""" + with pytest.raises(ValueError): + CombinedLoader([range(10)], 'test') + + +def test_combined_loader_loader_type_error(): + """Test the ValueError when wrapping the loaders""" + with pytest.raises(ValueError): + CombinedLoader(None, 'max_size_cycle') + + +def test_combined_loader_calc_length_mode_error(): + """Test the ValueError when calculating the number of batches""" + with pytest.raises(TypeError): + CombinedLoader._calc_num_batches(None) + + +def test_combined_loader_dict_min_size(): + """Test `CombinedLoader` of mode 'min_size' given mapping loaders""" loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} - combined_loader = CombinedLoaderIterator(loaders, 'min_size') + combined_loader = CombinedLoader(loaders, 'min_size') assert len(combined_loader) == min([len(v) for v in loaders.values()]) @@ -30,11 +99,12 @@ def test_combined_loader_iterator_dict_min_size(): assert idx == len(combined_loader) - 1 -def test_combined_loader_iterator_dict_max_size_cycle(): +def test_combined_loader_dict_max_size_cycle(): + """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders""" loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} - combined_loader = CombinedLoaderIterator(loaders, 'max_size_cycle') + combined_loader = CombinedLoader(loaders, 'max_size_cycle') assert len(combined_loader) == max([len(v) for v in loaders.values()]) @@ -46,11 +116,12 @@ def test_combined_loader_iterator_dict_max_size_cycle(): assert idx == len(combined_loader) - 1 -def test_combined_loader_iterator_sequence_min_size(): +def test_combined_loader_sequence_min_size(): + """Test `CombinedLoader` of mode 'min_size' given sequence loaders""" loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), torch.utils.data.DataLoader(range(20), batch_size=5)] - combined_loader = CombinedLoaderIterator(loaders, 'min_size') + combined_loader = CombinedLoader(loaders, 'min_size') assert len(combined_loader) == min([len(v) for v in loaders]) @@ -61,11 +132,12 @@ def test_combined_loader_iterator_sequence_min_size(): assert idx == len(combined_loader) - 1 -def test_combined_loader_iterator_sequence_max_size_cycle(): +def 
test_combined_loader_sequence_max_size_cycle(): + """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders""" loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), torch.utils.data.DataLoader(range(20), batch_size=5)] - combined_loader = CombinedLoaderIterator(loaders, 'max_size_cycle') + combined_loader = CombinedLoader(loaders, 'max_size_cycle') assert len(combined_loader) == max([len(v) for v in loaders]) diff --git a/tests/trainer/test_train_loader_patch.py b/tests/trainer/test_train_loader_patch.py new file mode 100644 index 0000000000000..485b40580534f --- /dev/null +++ b/tests/trainer/test_train_loader_patch.py @@ -0,0 +1,69 @@ +from collections import Sequence + +import pytest +import torch + +from torch.utils.data import TensorDataset +from pytorch_lightning.trainer.train_loader_patch import MultiIterator + + +def test_multi_iterator_dict_min_size(): + loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), + 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} + + combined_loader = MultiIterator(loaders, 'min_size') + + assert len(combined_loader) == min([len(v) for v in loaders.values()]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, dict) + assert len(item) == 2 + assert 'a' in item and 'b' in item + + assert idx == len(combined_loader) - 1 + + +def test_multi_iterator_dict_max_size_cycle(): + loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), + 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} + + combined_loader = MultiIterator(loaders, 'max_size_cycle') + + assert len(combined_loader) == max([len(v) for v in loaders.values()]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, dict) + assert len(item) == 2 + assert 'a' in item and 'b' in item + + assert idx == len(combined_loader) - 1 + + +def test_multi_iterator_sequence_min_size(): + loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), + torch.utils.data.DataLoader(range(20), batch_size=5)] + + combined_loader = MultiIterator(loaders, 'min_size') + + assert len(combined_loader) == min([len(v) for v in loaders]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 + + assert idx == len(combined_loader) - 1 + + +def test_multi_iterator_sequence_max_size_cycle(): + loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), + torch.utils.data.DataLoader(range(20), batch_size=5)] + + combined_loader = MultiIterator(loaders, 'max_size_cycle') + + assert len(combined_loader) == max([len(v) for v in loaders]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 + + assert idx == len(combined_loader) - 1 From 32aeb70079e0aedd5ab0ae25e868f37218a9ab27 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 8 Dec 2020 16:27:22 +0100 Subject: [PATCH 15/37] pep8 --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 831ace2b26ea0..67365b66f99a2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -23,7 +23,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.trainer.supporters 
import Accumulator, TensorRunningAccum +from pytorch_lightning.trainer.supporters import Accumulator, CombinedLoaderIterator, TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException From c4482c4dc96a43584fd7c4c75f3dbd9f61dce442 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 8 Dec 2020 16:42:00 +0100 Subject: [PATCH 16/37] Update train_loader_patch.py --- pytorch_lightning/trainer/train_loader_patch.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/pytorch_lightning/trainer/train_loader_patch.py b/pytorch_lightning/trainer/train_loader_patch.py index 60de3aaa4d9ff..34454059f1632 100644 --- a/pytorch_lightning/trainer/train_loader_patch.py +++ b/pytorch_lightning/trainer/train_loader_patch.py @@ -1,18 +1,3 @@ -''' - This patch solves two problems discussed in - https://github.com/PyTorchLightning/pytorch-lightning/pull/1959 - - The function train_dataloader can either return a single instance of - torch.utils.data.DataLoader or a dictionary of dataloaders. - - This patch fixes the length and iteration issus - and make the rest of the code oblivious of the underlying data structure. - - I will keep the name of the class but a better name is probable advisable - - @christofer-f -''' - import itertools from typing import Any, Union From 1a531a5d5c5da1ff29a46b8e02f9d5edb875770b Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 8 Dec 2020 16:42:48 +0100 Subject: [PATCH 17/37] Apply suggestions from code review Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/supporters.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index b48ec65952692..234933d064597 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -187,7 +187,7 @@ class CycleIterator(object): """ Iterator for restarting a dataloader if it runs out of samples """ - def __init__(self, loader: Any, length: int = None): + def __init__(self, loader: Any, length: Optional[int] = None): """ Args: @@ -318,8 +318,8 @@ class CombinedLoader(object): {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])} >>> combined_loader = CombinedLoader(loaders, 'min_size') - >>> for item in combined_loader: \ - print(item) + >>> for item in combined_loader: + ... print(item) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} @@ -479,8 +479,9 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: return apply_to_collection(loader_iters, Iterator, next) @staticmethod - def create_loader_iters(loaders: Union[Any, Iterator, - Sequence, Mapping]) -> Union[Any, Iterator, Sequence, Mapping]: + def create_loader_iters( + loaders: Union[Any, Iterator, Sequence, Mapping] + ) -> Union[Any, Iterator, Sequence, Mapping]: """ Create and return a collection of iterators from loaders. 
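
The behaviour exercised by the `test_combined_loader_iterator_dict_*` tests above can be reproduced with plain PyTorch. The following sketch is illustrative only, not the library implementation: it mimics what `CombinedLoaderIterator` does for a mapping of loaders (turn each loader into an iterator while keeping the mapping structure, then pull one batch per loader each step), and all names in it are made up for the example.

    from torch.utils.data import DataLoader

    # two loaders of different lengths, nested in a mapping
    loaders = {'a': DataLoader(range(10), batch_size=4),
               'b': DataLoader(range(20), batch_size=5)}

    # turn every loader into an iterator while preserving the dict structure
    loader_iters = {name: iter(loader) for name, loader in loaders.items()}

    # request one batch from each iterator per step; plain `next` stops as soon
    # as the shortest loader is exhausted, i.e. the 'min_size' behaviour
    while True:
        try:
            batch = {name: next(it) for name, it in loader_iters.items()}
        except StopIteration:
            break
        print(batch)  # e.g. {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
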
From a5d3652adccb2a608d19863762f2948a7dfbec64 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 8 Dec 2020 16:43:55 +0100 Subject: [PATCH 18/37] Update pytorch_lightning/trainer/supporters.py Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/supporters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 234933d064597..ad0b7ff3d1f33 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -243,7 +243,7 @@ def __next__(self) -> Any: finally: self.counter += 1 - def __len__(self) -> int: + def __len__(self) -> Union[int, float]: return self.length From dc94a59126562fc50c8e6902d778d7ac1be4c0ff Mon Sep 17 00:00:00 2001 From: justusschock Date: Wed, 9 Dec 2020 15:12:20 +0100 Subject: [PATCH 19/37] reviewer comments --- pytorch_lightning/trainer/data_loading.py | 2 -- pytorch_lightning/trainer/supporters.py | 4 ++-- tests/trainer/test_supporters.py | 8 ++++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 21768808215bb..54e0cab31d980 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -168,8 +168,6 @@ def reset_train_dataloader(self, model: LightningModule) -> None: # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode) - self.num_training_batches = 0 - self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index ad0b7ff3d1f33..02501a88800da 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -309,10 +309,10 @@ class CombinedLoader(object): Combines different dataloaders and allows sampling in parallel. 
Examples: - >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), \ + >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), 'b': torch.utils.data.DataLoader(range(15), batch_size=5)} >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle') - >>> for item in combined_loader: \ + >>> for item in combined_loader: print(item) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 070997dd48e6f..88812f01b1a22 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -45,7 +45,7 @@ def test_combined_dataset(dataset_1, dataset_2): def test_combined_dataset_length_mode_error(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='Invalid Mode'): CombinedDataset._calc_num_data([range(10)], 'test') @@ -66,19 +66,19 @@ def test_combined_loader_iterator_dict_min_size(): def test_combined_loader_init_mode_error(): """Test the ValueError when constructing `CombinedLoader`""" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='Invalid Mode'): CombinedLoader([range(10)], 'test') def test_combined_loader_loader_type_error(): """Test the ValueError when wrapping the loaders""" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='Invalid Datatype'): CombinedLoader(None, 'max_size_cycle') def test_combined_loader_calc_length_mode_error(): """Test the ValueError when calculating the number of batches""" - with pytest.raises(TypeError): + with pytest.raises(TypeError, match='Got Type NoneType, but expected one of Sequence, int or Mapping'): CombinedLoader._calc_num_batches(None) From 3c85cfb48fd0ad4d7194e2fcf3fee749876a7ae7 Mon Sep 17 00:00:00 2001 From: justusschock Date: Wed, 9 Dec 2020 16:13:51 +0100 Subject: [PATCH 20/37] fix stupid import --- pytorch_lightning/trainer/data_loading.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 54e0cab31d980..42ce98dbbb45d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -33,11 +33,9 @@ from copy import deepcopy from typing import Iterable from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.trainer.train_loader_patch import MultiIterator -from pytorch_lightning.trainer.supporters import CombinedLoaderIterator, CombinedLoader +from pytorch_lightning.trainer.supporters import CombinedLoader TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() -from pytorch_lightning.trainer.train_loader_patch import MagicClass class TrainerDataLoadingMixin(ABC): From c091414fc3ec8cd00f495034f31a0f1f97a3ff95 Mon Sep 17 00:00:00 2001 From: justusschock Date: Wed, 9 Dec 2020 16:32:32 +0100 Subject: [PATCH 21/37] add docs --- docs/source/multiple_loaders.rst | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/docs/source/multiple_loaders.rst b/docs/source/multiple_loaders.rst index ee7b32555c53f..f7746479e9672 100644 --- a/docs/source/multiple_loaders.rst +++ b/docs/source/multiple_loaders.rst @@ -9,14 +9,16 @@ Multiple Datasets Lightning supports multiple dataloaders in a few ways. 1. Create a dataloader that iterates multiple datasets under the hood. -2. In the validation and test loop you also have the option to return multiple dataloaders +2. 
In the training loop you can pass multiple loaders as a dict or list/tuple and lightning + will automatically combine the batches from different loaders. +3. In the validation and test loop you also have the option to return multiple dataloaders which lightning will call sequentially. ---------- Multiple training dataloaders ----------------------------- -For training, the best way to use multiple dataloaders is to create a ``DataLoader`` class +For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class which wraps your multiple dataloaders (this of course also works for testing and validation dataloaders). @@ -59,6 +61,31 @@ dataloaders). # SAME ... +However, with lightning you can also return multiple loaders and lightning will take care of batch combination. + +For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer.Trainer.multiple_trainloader_mode` + +.. testcode:: + + class LitModel(LightningModule): + + def train_dataloader(self): + + loader_a = torch.utils.data.DataLoader(range(6), batch_size=4) + loader_b = torch.utils.data.DataLoader(range(15), batch_size=5) + + # pass loaders as a dict. This will create batches like this: + # {'a': batch from loader_a, 'b': batch from loader_b} + loaders = {'a': loader_a, + 'b': loader_b} + + # OR: + # pass loaders as sequence. This will create batches like this: + # [batch from loader_a, batch from loader_b] + loaders = [loader_a, loader_b] + + return loaders + ---------- Test/Val dataloaders From ebb62772d7294e67cce754046493897efbcca3b9 Mon Sep 17 00:00:00 2001 From: justusschock Date: Wed, 9 Dec 2020 16:41:59 +0100 Subject: [PATCH 22/37] add back line separator --- pytorch_lightning/trainer/supporters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 02501a88800da..ad0b7ff3d1f33 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -309,10 +309,10 @@ class CombinedLoader(object): Combines different dataloaders and allows sampling in parallel. Examples: - >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), + >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), \ 'b': torch.utils.data.DataLoader(range(15), batch_size=5)} >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle') - >>> for item in combined_loader: + >>> for item in combined_loader: \ print(item) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} From e4e50abb23588ca9da4045182763c3815162f7ab Mon Sep 17 00:00:00 2001 From: justusschock Date: Wed, 9 Dec 2020 16:43:38 +0100 Subject: [PATCH 23/37] fix line sep --- pytorch_lightning/trainer/supporters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index ad0b7ff3d1f33..665d74f6638e5 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -309,11 +309,11 @@ class CombinedLoader(object): Combines different dataloaders and allows sampling in parallel. Examples: - >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), \ - 'b': torch.utils.data.DataLoader(range(15), batch_size=5)} + >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), + ... 
'b': torch.utils.data.DataLoader(range(15), batch_size=5)} >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle') - >>> for item in combined_loader: \ - print(item) + >>> for item in combined_loader: + ... print(item) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])} From 0227a68d1cb101543f852b8c485c8b70a7406730 Mon Sep 17 00:00:00 2001 From: justusschock Date: Wed, 9 Dec 2020 16:51:12 +0100 Subject: [PATCH 24/37] pep8 --- pytorch_lightning/trainer/data_loading.py | 2 -- pytorch_lightning/trainer/training_loop.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 42ce98dbbb45d..a0a876c2a5aaa 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -30,8 +30,6 @@ from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils -from copy import deepcopy -from typing import Iterable from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.trainer.supporters import CombinedLoader diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 67365b66f99a2..9c2d9391425ac 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -31,6 +31,7 @@ from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.warning_utils import WarningCache +from pytorch_lightning.utilities import TPU_AVAILABLE class TrainLoop: From 644a49003e453635b89a6cd1f28ec35173aedfcb Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 21 Dec 2020 11:11:49 +0100 Subject: [PATCH 25/37] Apply suggestions from code review --- docs/source/multiple_loaders.rst | 2 +- pytorch_lightning/trainer/data_loading.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/source/multiple_loaders.rst b/docs/source/multiple_loaders.rst index f7746479e9672..fb1aa33f80462 100644 --- a/docs/source/multiple_loaders.rst +++ b/docs/source/multiple_loaders.rst @@ -77,7 +77,7 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer # pass loaders as a dict. This will create batches like this: # {'a': batch from loader_a, 'b': batch from loader_b} loaders = {'a': loader_a, - 'b': loader_b} + 'b': loader_b} # OR: # pass loaders as sequence. 
This will create batches like this: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index a0a876c2a5aaa..bfecb86454ab0 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -29,12 +29,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_utils import is_overridden -from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils +from pytorch_lightning.utilities.xla_device_utils import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.trainer.supporters import CombinedLoader -TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() - class TrainerDataLoadingMixin(ABC): From 552e6a6bd6f0f5cc46b6106bc34c5a244c490695 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 21 Dec 2020 15:48:59 +0530 Subject: [PATCH 26/37] fix --- pytorch_lightning/trainer/data_loading.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index bfecb86454ab0..16ace48155d4d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -29,7 +29,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_utils import is_overridden -from pytorch_lightning.utilities.xla_device_utils import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.trainer.supporters import CombinedLoader From 3ab1907859056d915193508785c1be8e0033d1e0 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 21 Dec 2020 15:55:54 +0530 Subject: [PATCH 27/37] fix --- pytorch_lightning/trainer/training_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9c2d9391425ac..67365b66f99a2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -31,7 +31,6 @@ from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.warning_utils import WarningCache -from pytorch_lightning.utilities import TPU_AVAILABLE class TrainLoop: From a2d017ff54d494062d07de0dcf9374d75d56b302 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 31 Dec 2020 10:42:00 +0100 Subject: [PATCH 28/37] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/trainer/supporters.py | 8 ++++---- pytorch_lightning/trainer/train_loader_patch.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 665d74f6638e5..3d1ee165bbc07 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -272,11 +272,11 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str = 'min') -> Uni Args: datasets: a sequence/mapping datasets. Can be a collections of torch.utils.data.Dataset, Iterable or even None. - mode (str): 'min' or 'max'. Determine `CombinedDataset`'s length is the maximum or minimum of + mode: 'min' or 'max'. Determine `CombinedDataset`'s length is the maximum or minimum of the datasets. 
Returns: - length (int): the length of `CombinedDataset` + length: the length of `CombinedDataset` """ if mode not in ['min', 'max']: @@ -406,7 +406,7 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]: loaders: a collections of loaders. Returns: - length (int): the minimum length of loaders + length: the minimum length of loaders """ all_lengths = apply_to_collection(loaders, Iterable, get_len, @@ -486,7 +486,7 @@ def create_loader_iters( Create and return a collection of iterators from loaders. Args: - loaderss: a collections of loaders + loaders: a collections of loaders Returns a collections of iterators diff --git a/pytorch_lightning/trainer/train_loader_patch.py b/pytorch_lightning/trainer/train_loader_patch.py index 34454059f1632..247497c102e32 100644 --- a/pytorch_lightning/trainer/train_loader_patch.py +++ b/pytorch_lightning/trainer/train_loader_patch.py @@ -25,7 +25,8 @@ def _calc_num_batches(self, loaders, mode: str) -> Union[int, float]: elif mode == 'max_size_cycle': compare_func = max else: - raise ValueError(f"Invalid Mode: {mode}") + raise ValueError(f"Invalid Mode: {mode}. Supported modes are: {", ".join(self.SUPPORTED_MODES)}") + if isinstance(all_lengths, (int, float)): return all_lengths From 6af5c90012fef52a6da28c5fdd2c8c2e01f1af1d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 31 Dec 2020 10:53:56 +0100 Subject: [PATCH 29/37] Apply suggestions from code review Co-authored-by: Nicki Skafte --- pytorch_lightning/trainer/train_loader_patch.py | 15 +++++++++++++++ pytorch_lightning/utilities/data.py | 3 +-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/train_loader_patch.py b/pytorch_lightning/trainer/train_loader_patch.py index 247497c102e32..8a77ee0fac6d5 100644 --- a/pytorch_lightning/trainer/train_loader_patch.py +++ b/pytorch_lightning/trainer/train_loader_patch.py @@ -1,5 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import itertools + from typing import Any, Union from collections.abc import Iterable, Iterator, Mapping, Sequence diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index a997852e38e0f..1b4907ab8c2d4 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -49,8 +49,7 @@ def has_len(dataloader: DataLoader) -> bool: def get_len(dataloader: DataLoader) -> Union[int, float]: - """ Return the length of the given DataLoader. If __len__ method is not implemented, - return float('inf'). """ + """ Return the length of the given DataLoader. If ``__len__`` method is not implemented, return float('inf'). 
""" if has_len(dataloader): return len(dataloader) From e138c7adedfd9d445f71ecd24ae4e01d38068c7f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 31 Dec 2020 10:58:48 +0100 Subject: [PATCH 30/37] flake8 --- pytorch_lightning/trainer/train_loader_patch.py | 14 +++++--------- pytorch_lightning/trainer/training_loop.py | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/train_loader_patch.py b/pytorch_lightning/trainer/train_loader_patch.py index 8a77ee0fac6d5..26333064e9c3b 100644 --- a/pytorch_lightning/trainer/train_loader_patch.py +++ b/pytorch_lightning/trainer/train_loader_patch.py @@ -12,13 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools - from typing import Any, Union -from collections.abc import Iterable, Iterator, Mapping, Sequence - -from torch.utils.data import DataLoader +from collections.abc import Iterable, Mapping, Sequence from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -32,16 +28,16 @@ def __init__(self, loaders: Any, mode: str = 'min_size') -> None: self.num_batches = self._calc_num_batches(loaders, mode) def _calc_num_batches(self, loaders, mode: str) -> Union[int, float]: - all_lengths = apply_to_collection(loaders, Iterable, get_len, - wrong_dtype=(Sequence, Mapping)) + all_lengths = apply_to_collection( + loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping) + ) if mode == 'min_size': compare_func = min elif mode == 'max_size_cycle': compare_func = max else: - raise ValueError(f"Invalid Mode: {mode}. Supported modes are: {", ".join(self.SUPPORTED_MODES)}") - + raise ValueError(f"Invalid Mode: {mode}. Supported modes are: {', '.join(self.SUPPORTED_MODES)}") if isinstance(all_lengths, (int, float)): return all_lengths diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 67365b66f99a2..831ace2b26ea0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -23,7 +23,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.trainer.supporters import Accumulator, CombinedLoaderIterator, TensorRunningAccum +from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException From 9265651989b5ce9e4050db52c9630af4b176b25c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 31 Dec 2020 11:24:44 +0100 Subject: [PATCH 31/37] chlog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index edc319511b195..5a779a2324c12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added +- Add Support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959)) + ### Changed From 7669a40625eb77eb747db0da731d163c545ea7ea Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 4 Jan 2021 10:09:57 +0100 Subject: [PATCH 32/37] Update pytorch_lightning/trainer/supporters.py Co-authored-by: chaton --- pytorch_lightning/trainer/supporters.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 3d1ee165bbc07..fca028d72dba3 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -251,6 +251,7 @@ class CombinedDataset(object): """ Combine multiple datasets and compute their statistics """ + MODES = ['min', 'max'] def __init__(self, datasets: Union[Sequence, Mapping]): """ From 8e9bd3d0858c41ddcf40a91e03226c76c778f102 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 4 Jan 2021 10:44:18 +0100 Subject: [PATCH 33/37] add missing test --- tests/base/model_train_dataloaders.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py index 4744e5d52ec65..65873cfa8d6c4 100644 --- a/tests/base/model_train_dataloaders.py +++ b/tests/base/model_train_dataloaders.py @@ -38,7 +38,11 @@ def train_dataloader__zero_length(self): dataloader.dataset.targets = dataloader.dataset.targets[:0] return dataloader - def train_dataloader__multiple(self): + def train_dataloader__multiple_mapping(self): """Return a mapping loaders with different lengths""" return {'a': self.dataloader(train=True, num_samples=100), 'b': self.dataloader(train=True, num_samples=50)} + + def train_dataloader__multiple_sequence(self): + return [self.dataloader(train=True, num_samples=100), + self.dataloader(train=True, num_samples=50)] From 854756b8613ed26be00c1c28b16b76f9a5d29711 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 4 Jan 2021 10:49:20 +0100 Subject: [PATCH 34/37] fix dataset length --- pytorch_lightning/trainer/supporters.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index fca028d72dba3..ae039d07ed7d9 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -251,43 +251,42 @@ class CombinedDataset(object): """ Combine multiple datasets and compute their statistics """ - MODES = ['min', 'max'] - def __init__(self, datasets: Union[Sequence, Mapping]): + def __init__(self, datasets: Union[Sequence, Mapping], mode: str): """ Args: datasets: a sequence/mapping datasets. Can be a collections of torch.utils.Dataset, Iterable or even None. + mode: whether to use the minimum number of batches in all samples or the maximum + number of batches in all samples. """ self.datasets = datasets - - self.max_len = self._calc_num_data(self.datasets, 'max') - self.min_len = self._calc_num_data(self.datasets, 'min') + self.mode = mode @staticmethod - def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str = 'min') -> Union[int, float]: + def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: """ Compute the length of `CombinedDataset` according to the `mode`. 
Args: datasets: a sequence/mapping datasets. Can be a collections of torch.utils.data.Dataset, Iterable or even None. - mode: 'min' or 'max'. Determine `CombinedDataset`'s length is the maximum or minimum of + mode: Determine `CombinedDataset`'s length is the maximum or minimum of the datasets. Returns: length: the length of `CombinedDataset` """ - if mode not in ['min', 'max']: + if mode not in ['min_size', 'max_size_cycle']: raise ValueError(f"Invalid Mode: {mode}") # extract the lengths all_lengths = apply_to_collection(datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping)) - compute_func = eval(mode) + compute_func = {'min_size': min, 'max_size_cycle': max} if isinstance(all_lengths, (int, float)): length = all_lengths @@ -302,7 +301,7 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str = 'min') -> Uni def __len__(self) -> int: """Return the minimum length of the datasets.""" - return self.min_len + return self._calc_num_data(self.datasets, self.mode) class CombinedLoader(object): @@ -341,7 +340,7 @@ def __init__(self, loaders: Any, mode: str = 'min_size'): datasets = apply_to_collection(self.loaders, Iterable, getattr, 'dataset', None, wrong_dtype=(Sequence, Mapping)) # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader - self.dataset = CombinedDataset(datasets) + self.dataset = CombinedDataset(datasets, mode) if mode not in self.SUPPORTED_MODES: raise ValueError(f"Invalid Mode: {mode}") From 55037d4ebad2dc4fbd78f5bb0d21e1d461b25896 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 4 Jan 2021 10:49:51 +0100 Subject: [PATCH 35/37] Update supporters.py --- pytorch_lightning/trainer/supporters.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index ae039d07ed7d9..81d4f0cfcbcf1 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -307,6 +307,11 @@ def __len__(self) -> int: class CombinedLoader(object): """ Combines different dataloaders and allows sampling in parallel. + + Supported modes are 'min_size', which raises StopIteration after the shortest loader + (the one with the lowest number of batches) is done, and 'max_size_cycle` which raises + StopIteration after the longest loader (the one with most batches) is done, while cycling + through the shorter loaders. Examples: >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), From 9d42c7ec57647dd714d54f50284b8c3e10d1baad Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 4 Jan 2021 10:50:20 +0100 Subject: [PATCH 36/37] remove unused patch --- .../trainer/train_loader_patch.py | 79 ------------------- 1 file changed, 79 deletions(-) delete mode 100644 pytorch_lightning/trainer/train_loader_patch.py diff --git a/pytorch_lightning/trainer/train_loader_patch.py b/pytorch_lightning/trainer/train_loader_patch.py deleted file mode 100644 index 26333064e9c3b..0000000000000 --- a/pytorch_lightning/trainer/train_loader_patch.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Any, Union -from collections.abc import Iterable, Mapping, Sequence - -from pytorch_lightning.utilities.data import get_len -from pytorch_lightning.utilities.apply_func import apply_to_collection - - -class MultiIterator(object): - SUPPORTED_MODES = ('min_size', 'max_size_cycle') - - def __init__(self, loaders: Any, mode: str = 'min_size') -> None: - self.loaders = loaders - self.num_batches = self._calc_num_batches(loaders, mode) - - def _calc_num_batches(self, loaders, mode: str) -> Union[int, float]: - all_lengths = apply_to_collection( - loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping) - ) - - if mode == 'min_size': - compare_func = min - elif mode == 'max_size_cycle': - compare_func = max - else: - raise ValueError(f"Invalid Mode: {mode}. Supported modes are: {', '.join(self.SUPPORTED_MODES)}") - - if isinstance(all_lengths, (int, float)): - return all_lengths - if isinstance(all_lengths, Mapping): - return compare_func(all_lengths.values()) - elif isinstance(all_lengths, Sequence): - return compare_func(all_lengths) - - raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, Mapping, int or float') - - def __len__(self) -> Union[int, float]: - # Return type might be int or inf. Inf will cause type error when calling len() - return self.num_batches - - def __iter__(self): - if isinstance(self.loaders, Mapping): - gens = {} - for batch_idx in range(self.num_batches): - rv = {} - for loader_name, loader in self.loaders.items(): - # If reaching the end of the iterator, recreate one - # because shuffle=True in dataloader, the iterator will have a different order - if batch_idx % len(loader) == 0: - gens[loader_name] = iter(loader) - rv[loader_name] = next(gens[loader_name]) - yield rv - elif isinstance(self.loaders, Sequence): - gens = [None] * self.num_batches - for batch_idx in range(self.num_batches): - rv = [] - for idx, loader in enumerate(self.loaders): - # If reaching the end of the iterator, recreate one - # because shuffle=True in dataloader, the iterator will have a different order - if batch_idx % len(loader) == 0: - gens[idx] = iter(loader) - rv.append(next(gens[idx])) - yield rv - - return iter(self.loaders) From 232c7ce32676931843645c26dfae820961fa814d Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 4 Jan 2021 10:50:55 +0100 Subject: [PATCH 37/37] remove tests of otherwise unused patch --- tests/trainer/test_train_loader_patch.py | 69 ------------------------ 1 file changed, 69 deletions(-) delete mode 100644 tests/trainer/test_train_loader_patch.py diff --git a/tests/trainer/test_train_loader_patch.py b/tests/trainer/test_train_loader_patch.py deleted file mode 100644 index 485b40580534f..0000000000000 --- a/tests/trainer/test_train_loader_patch.py +++ /dev/null @@ -1,69 +0,0 @@ -from collections import Sequence - -import pytest -import torch - -from torch.utils.data import TensorDataset -from pytorch_lightning.trainer.train_loader_patch import MultiIterator - - -def test_multi_iterator_dict_min_size(): - loaders = {'a': 
torch.utils.data.DataLoader(range(10), batch_size=4), - 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} - - combined_loader = MultiIterator(loaders, 'min_size') - - assert len(combined_loader) == min([len(v) for v in loaders.values()]) - - for idx, item in enumerate(combined_loader): - assert isinstance(item, dict) - assert len(item) == 2 - assert 'a' in item and 'b' in item - - assert idx == len(combined_loader) - 1 - - -def test_multi_iterator_dict_max_size_cycle(): - loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4), - 'b': torch.utils.data.DataLoader(range(20), batch_size=5)} - - combined_loader = MultiIterator(loaders, 'max_size_cycle') - - assert len(combined_loader) == max([len(v) for v in loaders.values()]) - - for idx, item in enumerate(combined_loader): - assert isinstance(item, dict) - assert len(item) == 2 - assert 'a' in item and 'b' in item - - assert idx == len(combined_loader) - 1 - - -def test_multi_iterator_sequence_min_size(): - loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), - torch.utils.data.DataLoader(range(20), batch_size=5)] - - combined_loader = MultiIterator(loaders, 'min_size') - - assert len(combined_loader) == min([len(v) for v in loaders]) - - for idx, item in enumerate(combined_loader): - assert isinstance(item, Sequence) - assert len(item) == 2 - - assert idx == len(combined_loader) - 1 - - -def test_multi_iterator_sequence_max_size_cycle(): - loaders = [torch.utils.data.DataLoader(range(10), batch_size=4), - torch.utils.data.DataLoader(range(20), batch_size=5)] - - combined_loader = MultiIterator(loaders, 'max_size_cycle') - - assert len(combined_loader) == max([len(v) for v in loaders]) - - for idx, item in enumerate(combined_loader): - assert isinstance(item, Sequence) - assert len(item) == 2 - - assert idx == len(combined_loader) - 1
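
With the obsolete `MultiIterator` patch and its tests removed, `CombinedLoader` is the public entry point left by this series: the Trainer wraps whatever `train_dataloader()` returns (a dict or a sequence of loaders) in a `CombinedLoader` using `multiple_trainloader_mode`. A minimal usage sketch, assuming an installed pytorch-lightning build that already contains these patches:

    from torch.utils.data import DataLoader
    from pytorch_lightning.trainer.supporters import CombinedLoader

    loaders = {'a': DataLoader(range(6), batch_size=4),
               'b': DataLoader(range(15), batch_size=5)}

    # 'max_size_cycle': iterate until the longest loader ('b', 3 batches) is done,
    # restarting the shorter loader 'a' as needed
    combined = CombinedLoader(loaders, 'max_size_cycle')
    assert len(combined) == 3
    for batch in combined:
        print(batch)  # {'a': tensor(...), 'b': tensor(...)}

    # 'min_size': stop after the shortest loader ('a', 2 batches) is exhausted
    combined = CombinedLoader(loaders, 'min_size')
    assert len(combined) == 2
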