From 8d8140de08dc8462937967c4b43d6dca36ac3611 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 17:39:39 +0200 Subject: [PATCH 01/19] Update supporters.py --- pytorch_lightning/trainer/supporters.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 18a012da54760..91ba316d37348 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -356,6 +356,7 @@ def sampler(self) -> Union[Iterable, Sequence, Mapping]: """Return a collections of samplers extracting from loaders.""" return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, 'sampler', None) + def _wrap_loaders_max_size_cycle(self) -> Any: """ Wraps all loaders to make sure they are cycled until the longest loader is exhausted @@ -366,27 +367,11 @@ def _wrap_loaders_max_size_cycle(self) -> Any: """ all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) - if isinstance(all_lengths, (int, float)): - length = all_lengths - - elif isinstance(all_lengths, Mapping): - length = max(all_lengths.values()) + length = _nested_calc_num_data(all_lengths, max) - elif isinstance(all_lengths, Sequence): - length = max(all_lengths) - - if isinstance(self.loaders, Mapping): - self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()}) - - elif isinstance(self.loaders, Sequence): - self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders]) - - # dataloaders are iterable but not sequence - elif isinstance(self.loaders, Iterable): - # only one dataloader, just keep it the same. 
- pass - else: - raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}') + # multiple loaders + if isinstance(self.loaders, (Sequence, Mapping)): + self.loaders = apply_to_collection(self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping)) def __iter__(self) -> Any: """ From e6df73c2815e7527ced7f050ba941930817cf687 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 17:40:00 +0200 Subject: [PATCH 02/19] Update apply_func.py --- pytorch_lightning/utilities/apply_func.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index e100a803bcd00..1b9996f3166c4 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -85,13 +85,13 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) + return elem_type({k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for k, v in data.items()}) if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple - return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) + return elem_type(*(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data)) if isinstance(data, Sequence) and not isinstance(data, str): - return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) + return elem_type([apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data]) # data is neither of dtype, nor a collection return data From a5cd8ae487bbb5117eac7ac0c50e1c81f18e88c4 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 18:00:17 +0200 Subject: [PATCH 03/19] Update supporters.py --- pytorch_lightning/trainer/supporters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 91ba316d37348..d2e8b79f9ea81 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -475,7 +475,7 @@ def create_loader_iters( def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable): - if isinstance(data, int): + if isinstance(data, (float, int)): return data if isinstance(data, Mapping): From fb3d4b3694bd4bb52d5e6f440e0f2e6ed1edd61d Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 18:01:10 +0200 Subject: [PATCH 04/19] Update model_train_dataloaders.py --- tests/base/model_train_dataloaders.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py index 50c85ddc3f79d..4869f300231b3 100644 --- a/tests/base/model_train_dataloaders.py +++ b/tests/base/model_train_dataloaders.py @@ -39,7 +39,19 @@ def train_dataloader__zero_length(self): def train_dataloader__multiple_mapping(self): """Return a mapping loaders with different lengths""" - return { - 'a': self.dataloader(train=True, num_samples=100), - 'b': self.dataloader(train=True, num_samples=50), - } + + # List[DataLoader] + loaders_a_b = [ + 
self.dataloader(num_samples=100, train=True), + self.dataloader(num_samples=50, train=True) + ] + loaders_c_d_e = [ + self.dataloader(num_samples=50, train=True), + self.dataloader(num_samples=50, train=True), + self.dataloader(num_samples=50, train=True) + ] + + + # Dict[str, List[DataLoader]] + loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e} + return loaders From cee615972a98c320fb2fb77fc47eb613dbe4ece1 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 18:01:37 +0200 Subject: [PATCH 05/19] Update model_train_steps.py --- tests/base/model_train_steps.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 9e26e9fc93bae..12c60ae1ae138 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -11,15 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from abc import ABC from collections import OrderedDict +import torch + class TrainingStepVariations(ABC): """ Houses all variations of training steps """ + test_step_inf_loss = float('inf') + def training_step(self, batch, batch_idx, optimizer_idx=None): """Lightning calls this inside the training loop""" self.training_step_called = True @@ -44,15 +49,27 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): }) return output + def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None): + output = self.training_step(batch, batch_idx, optimizer_idx) + if batch_idx == self.test_step_inf_loss: + if isinstance(output, dict): + output['loss'] *= torch.tensor(math.inf) # make loss infinite + else: + output /= 0 + return output + def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None): """Training step for multiple train loaders""" assert isinstance(batch, dict) assert len(batch) == 2 - assert 'a' in batch and 'b' in batch + + assert 'a_b' in batch and 'c_d_e' in batch, batch.keys() + assert isinstance(batch['a_b'], list) and len(batch['a_b']) == 2 + assert isinstance(batch['c_d_e'], list) and len(batch['c_d_e']) == 3 # forward pass - x, y = batch['a'] + x, y = batch['a_b'][0] x = x.view(x.size(0), -1) y_hat = self(x) From 6095966d61a93600de6f2bfdd7163e3c0821aa84 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 18:02:01 +0200 Subject: [PATCH 06/19] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 45 ++++++++++++++++--------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d988943c06088..eb0235a0d750b 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -26,6 +26,7 @@ import tests.helpers.pipelines as tpipes from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -124,7 +125,7 @@ def test_multiple_val_dataloader(tmpdir): trainer.fit(model) # verify training completed - assert 
trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # verify there are 2 val loaders assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly' @@ -193,7 +194,7 @@ def test_train_dataloader_passed_to_fit(tmpdir): fit_options = dict(train_dataloader=model.dataloader(train=True)) trainer.fit(model, **fit_options) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) @@ -219,7 +220,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): ) trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert len(trainer.val_dataloaders) == n if ckpt_path == 'specific': @@ -305,7 +306,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, assert sum(1 for _ in dl) == num_batches trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf')) assert epoch_cb.train_epoch_count == int(limit_train_batches > 0) assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf')) @@ -344,7 +345,7 @@ def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batch val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) assert epoch_cb.train_batches_seen == limit_train_batches * epochs @@ -388,7 +389,7 @@ def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0) assert epoch_cb.val_batches_seen == limit_val_batches * epochs @@ -427,7 +428,7 @@ def test_datasets_dataloaders_with_limit_num_batches( test_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) @@ -610,12 +611,12 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): # fit model trainer = 
Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) @@ -673,7 +674,7 @@ def test_inf_train_dataloader(tmpdir, check_interval): ) trainer.fit(model) # verify training completed - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @pytest.mark.parametrize('check_interval', [1.0]) @@ -692,7 +693,7 @@ def test_inf_val_dataloader(tmpdir, check_interval): trainer.fit(model) # verify training completed - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" def test_error_on_zero_len_dataloader(tmpdir): @@ -766,7 +767,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): train_dl = model.dataloader(train=False) train_dl.num_workers = 0 - train_multi_dl = {'a': train_dl, 'b': train_dl} + train_multi_dl = {'a_b': [train_dl, train_dl], 'c_d_e': [train_dl, train_dl, train_dl]} val_multi_dl = [val_dl, val_dl] test_multi_dl = [train_dl, train_dl] @@ -1098,7 +1099,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir): replace_sampler_ddp=True, ) trainer.fit(model) - assert trainer.state.finished, "DDP Training failed" + assert trainer.state == TrainerState.FINISHED, "DDP Training failed" @RunIf(min_gpus=3) @@ -1150,7 +1151,7 @@ def train_dataloader(self): # we expect the reduction for the metrics also to happen on the last batch # where we will get fewer metrics than gpus trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @pytest.mark.parametrize(['multiple_trainloader_mode', 'num_training_batches'], [ @@ -1191,7 +1192,7 @@ def test_val_dataloader_not_implemented_error(tmpdir, check_interval): ) trainer.fit(model) # verify training completed - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @pytest.mark.parametrize('check_interval', [50, 1.0]) @@ -1205,7 +1206,7 @@ def test_train_dataloader_not_implemented_error(tmpdir, check_interval): trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=check_interval) trainer.fit(model) # verify training completed - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" def test_train_dataloader_not_implemented_error_failed(tmpdir): @@ -1254,7 +1255,7 @@ def test_dataloaders_load_only_once(tmpdir): max_epochs=3, ) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert 
len(trainer.dev_debugger.val_dataloader_calls) == 1 assert len(trainer.dev_debugger.test_dataloader_calls) == 0 @@ -1285,7 +1286,7 @@ def test_dataloaders_load_only_once_val_interval(tmpdir): max_epochs=3, ) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" trainer.test() @@ -1329,7 +1330,7 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir): max_epochs=3, ) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert len(trainer.dev_debugger.val_dataloader_calls) == 1 assert len(trainer.dev_debugger.test_dataloader_calls) == 0 @@ -1359,7 +1360,7 @@ def test_dataloaders_load_every_epoch(tmpdir): max_epochs=3, ) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" trainer.test() @@ -1408,7 +1409,7 @@ def validation_step(self, batch, batch_idx): callbacks=[checkpoint_callback], ) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" trainer.test() @@ -1460,7 +1461,7 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): max_epochs=3, ) trainer.fit(model, train_loader, val_loader) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" trainer.test(test_dataloaders=test_loader) From 304af2b53c0fcfe4d796032476960186c210d0b0 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 18:05:03 +0200 Subject: [PATCH 07/19] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af142fdba3414..b8b54e16d51a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - +- Fixed Parsing of Multiple Trainloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) - Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362)) From 1975cd6538bb7679e006059b4e5c8ec62f72f42b Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 18:11:02 +0200 Subject: [PATCH 08/19] Update model_train_steps.py --- tests/base/model_train_steps.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 12c60ae1ae138..c24cf5ded575a 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -11,20 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math from abc import ABC from collections import OrderedDict -import torch - class TrainingStepVariations(ABC): """ Houses all variations of training steps """ - test_step_inf_loss = float('inf') - def training_step(self, batch, batch_idx, optimizer_idx=None): """Lightning calls this inside the training loop""" self.training_step_called = True @@ -49,15 +44,6 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): }) return output - def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None): - output = self.training_step(batch, batch_idx, optimizer_idx) - if batch_idx == self.test_step_inf_loss: - if isinstance(output, dict): - output['loss'] *= torch.tensor(math.inf) # make loss infinite - else: - output /= 0 - return output - def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None): """Training step for multiple train loaders""" From ec77237e914dcb98f2f5e344a60f9bfd2d49c194 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 18:12:48 +0200 Subject: [PATCH 09/19] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 43 +++++++++++++++---------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index eb0235a0d750b..1f0317bbd5d18 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -26,7 +26,6 @@ import tests.helpers.pipelines as tpipes from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -125,7 +124,7 @@ def test_multiple_val_dataloader(tmpdir): trainer.fit(model) # verify training completed - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" # verify there are 2 val loaders assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly' @@ -194,7 +193,7 @@ def test_train_dataloader_passed_to_fit(tmpdir): fit_options = dict(train_dataloader=model.dataloader(train=True)) trainer.fit(model, **fit_options) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) @@ -220,7 +219,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): ) trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" assert len(trainer.val_dataloaders) == n if ckpt_path == 'specific': @@ -306,7 +305,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, assert sum(1 for _ in dl) == num_batches trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf')) 
assert epoch_cb.train_epoch_count == int(limit_train_batches > 0) assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf')) @@ -345,7 +344,7 @@ def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batch val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) assert epoch_cb.train_batches_seen == limit_train_batches * epochs @@ -389,7 +388,7 @@ def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0) assert epoch_cb.val_batches_seen == limit_val_batches * epochs @@ -428,7 +427,7 @@ def test_datasets_dataloaders_with_limit_num_batches( test_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) @@ -611,12 +610,12 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) @@ -674,7 +673,7 @@ def test_inf_train_dataloader(tmpdir, check_interval): ) trainer.fit(model) # verify training completed - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize('check_interval', [1.0]) @@ -693,7 +692,7 @@ def test_inf_val_dataloader(tmpdir, check_interval): trainer.fit(model) # verify training completed - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" def test_error_on_zero_len_dataloader(tmpdir): @@ -1099,7 +1098,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir): replace_sampler_ddp=True, ) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, "DDP Training failed" + trainer.state.finished, "DDP 
Training failed" @RunIf(min_gpus=3) @@ -1151,7 +1150,7 @@ def train_dataloader(self): # we expect the reduction for the metrics also to happen on the last batch # where we will get fewer metrics than gpus trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize(['multiple_trainloader_mode', 'num_training_batches'], [ @@ -1192,7 +1191,7 @@ def test_val_dataloader_not_implemented_error(tmpdir, check_interval): ) trainer.fit(model) # verify training completed - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize('check_interval', [50, 1.0]) @@ -1206,7 +1205,7 @@ def test_train_dataloader_not_implemented_error(tmpdir, check_interval): trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=check_interval) trainer.fit(model) # verify training completed - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" def test_train_dataloader_not_implemented_error_failed(tmpdir): @@ -1255,7 +1254,7 @@ def test_dataloaders_load_only_once(tmpdir): max_epochs=3, ) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" assert len(trainer.dev_debugger.val_dataloader_calls) == 1 assert len(trainer.dev_debugger.test_dataloader_calls) == 0 @@ -1286,7 +1285,7 @@ def test_dataloaders_load_only_once_val_interval(tmpdir): max_epochs=3, ) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" trainer.test() @@ -1330,7 +1329,7 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir): max_epochs=3, ) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" assert len(trainer.dev_debugger.val_dataloader_calls) == 1 assert len(trainer.dev_debugger.test_dataloader_calls) == 0 @@ -1360,7 +1359,7 @@ def test_dataloaders_load_every_epoch(tmpdir): max_epochs=3, ) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" trainer.test() @@ -1409,7 +1408,7 @@ def validation_step(self, batch, batch_idx): callbacks=[checkpoint_callback], ) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" trainer.test() @@ -1461,7 +1460,7 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): max_epochs=3, ) trainer.fit(model, train_loader, val_loader) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + trainer.state.finished, f"Training failed with {trainer.state}" trainer.test(test_dataloaders=test_loader) From d175f3125d90fd3354dfe7dd987c81b8f7350028 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 18:13:38 +0200 Subject: [PATCH 10/19] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 43 ++++++++++++++++--------------- 1 file changed, 
22 insertions(+), 21 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 1f0317bbd5d18..7484fe86c23e5 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -26,6 +26,7 @@ import tests.helpers.pipelines as tpipes from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -124,7 +125,7 @@ def test_multiple_val_dataloader(tmpdir): trainer.fit(model) # verify training completed - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" # verify there are 2 val loaders assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly' @@ -193,7 +194,7 @@ def test_train_dataloader_passed_to_fit(tmpdir): fit_options = dict(train_dataloader=model.dataloader(train=True)) trainer.fit(model, **fit_options) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) @@ -219,7 +220,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): ) trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" assert len(trainer.val_dataloaders) == n if ckpt_path == 'specific': @@ -305,7 +306,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, assert sum(1 for _ in dl) == num_batches trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf')) assert epoch_cb.train_epoch_count == int(limit_train_batches > 0) assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf')) @@ -344,7 +345,7 @@ def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batch val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) assert epoch_cb.train_batches_seen == limit_train_batches * epochs @@ -388,7 +389,7 @@ def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0) assert epoch_cb.val_batches_seen == limit_val_batches * 
epochs @@ -427,7 +428,7 @@ def test_datasets_dataloaders_with_limit_num_batches( test_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0) @@ -610,12 +611,12 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" if ckpt_path == 'specific': ckpt_path = trainer.checkpoint_callback.best_model_path trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) @@ -673,7 +674,7 @@ def test_inf_train_dataloader(tmpdir, check_interval): ) trainer.fit(model) # verify training completed - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize('check_interval', [1.0]) @@ -692,7 +693,7 @@ def test_inf_val_dataloader(tmpdir, check_interval): trainer.fit(model) # verify training completed - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" def test_error_on_zero_len_dataloader(tmpdir): @@ -1098,7 +1099,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir): replace_sampler_ddp=True, ) trainer.fit(model) - trainer.state.finished, "DDP Training failed" + assert trainer.state.finished, "DDP Training failed" @RunIf(min_gpus=3) @@ -1150,7 +1151,7 @@ def train_dataloader(self): # we expect the reduction for the metrics also to happen on the last batch # where we will get fewer metrics than gpus trainer.fit(model) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize(['multiple_trainloader_mode', 'num_training_batches'], [ @@ -1191,7 +1192,7 @@ def test_val_dataloader_not_implemented_error(tmpdir, check_interval): ) trainer.fit(model) # verify training completed - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize('check_interval', [50, 1.0]) @@ -1205,7 +1206,7 @@ def test_train_dataloader_not_implemented_error(tmpdir, check_interval): trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=check_interval) trainer.fit(model) # verify training completed - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" def test_train_dataloader_not_implemented_error_failed(tmpdir): @@ -1254,7 +1255,7 @@ def test_dataloaders_load_only_once(tmpdir): max_epochs=3, ) trainer.fit(model) - trainer.state.finished, f"Training failed with {trainer.state}" + assert 
trainer.state.finished, f"Training failed with {trainer.state}" assert len(trainer.dev_debugger.val_dataloader_calls) == 1 assert len(trainer.dev_debugger.test_dataloader_calls) == 0 @@ -1285,7 +1286,7 @@ def test_dataloaders_load_only_once_val_interval(tmpdir): max_epochs=3, ) trainer.fit(model) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" trainer.test() @@ -1329,7 +1330,7 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir): max_epochs=3, ) trainer.fit(model) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" assert len(trainer.dev_debugger.val_dataloader_calls) == 1 assert len(trainer.dev_debugger.test_dataloader_calls) == 0 @@ -1359,7 +1360,7 @@ def test_dataloaders_load_every_epoch(tmpdir): max_epochs=3, ) trainer.fit(model) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" trainer.test() @@ -1408,7 +1409,7 @@ def validation_step(self, batch, batch_idx): callbacks=[checkpoint_callback], ) trainer.fit(model) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" trainer.test() @@ -1460,7 +1461,7 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): max_epochs=3, ) trainer.fit(model, train_loader, val_loader) - trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.state.finished, f"Training failed with {trainer.state}" trainer.test(test_dataloaders=test_loader) From 1494d54ab095359e5ac7e4eb6189c57e8a3def66 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 19:20:16 +0200 Subject: [PATCH 11/19] Update supporters.py --- pytorch_lightning/trainer/supporters.py | 41 ++++++++++++++++++++----- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index d2e8b79f9ea81..51de8be1d30b1 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -262,8 +262,7 @@ def max_len(self) -> Union[int, float]: def min_len(self) -> Union[int, float]: return self._calc_num_data(self.datasets, 'min_size') - @staticmethod - def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: + def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: """ Compute the length of `CombinedDataset` according to the `mode`. 
@@ -281,9 +280,7 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, raise MisconfigurationException(f"Invalid Mode: {mode}") # extract the lengths - all_lengths = apply_to_collection( - datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping) - ) + all_lengths = self._get_len_recursive(datasets) compute_func = CombinedDataset.COMPUTE_FUNCS[mode] @@ -294,6 +291,34 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, return length + def _get_len_recursive(self, data): + if isinstance(data, Dataset): + return len(data) + + elif isinstance(data, (float, int)): + return data + + elif isinstance(data, Mapping): + if isinstance(list(data.values())[0], (Mapping, Sequence, Dataset, Iterable)): + return {k: self._get_len_recursive(v) for k, v in data.items()} + else: + return self._get_len(data) + elif isinstance(data, Sequence): + data = list(data) + if isinstance(data[0], (Mapping, Sequence, Dataset, Iterable)): + return [self._get_len_recursive(v) for v in data] + else: + return self._get_len(data) + + return self._get_len(data) + + @staticmethod + def _get_len(dataset): + try: + return len(dataset) + except (TypeError, NotImplementedError): + return float('inf') + def __len__(self) -> int: """Return the minimum length of the datasets.""" return self._calc_num_data(self.datasets, self.mode) @@ -335,6 +360,9 @@ def __init__(self, loaders: Any, mode: str = 'min_size'): 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones. """ + if mode not in self.SUPPORTED_MODES: + raise MisconfigurationException(f"Invalid Mode: {mode}") + self.loaders = loaders datasets = apply_to_collection( @@ -343,9 +371,6 @@ def __init__(self, loaders: Any, mode: str = 'min_size'): # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader self.dataset = CombinedDataset(datasets, mode) - if mode not in self.SUPPORTED_MODES: - raise MisconfigurationException(f"Invalid Mode: {mode}") - self.mode = mode if self.mode == 'max_size_cycle': From a8c87534f2d6f3ecec2175874785584ce099c62f Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 7 May 2021 19:21:01 +0200 Subject: [PATCH 12/19] Update test_supporters.py --- tests/trainer/test_supporters.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 6d6b1e9ad1bdf..169c8cb80b04d 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -123,8 +123,9 @@ def test_combined_dataset(dataset_1, dataset_2): def test_combined_dataset_length_mode_error(): + dset = CombinedDataset([range(10)]) with pytest.raises(MisconfigurationException, match="Invalid Mode"): - CombinedDataset._calc_num_data([range(10)], "test") + dset._calc_num_data([range(10)], "test") def test_combined_loader_iterator_dict_min_size(): @@ -146,13 +147,13 @@ def test_combined_loader_iterator_dict_min_size(): def test_combined_loader_init_mode_error(): """Test the ValueError when constructing `CombinedLoader`""" - with pytest.raises(MisconfigurationException, match="selected unsupported mode"): + with pytest.raises(MisconfigurationException, match="Invalid Mode"): CombinedLoader([range(10)], "testtt") def test_combined_loader_loader_type_error(): """Test the ValueError when wrapping the loaders""" - with pytest.raises(ValueError, match="Invalid Datatype"): + with 
pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"): CombinedLoader(None, "max_size_cycle") From 55224838c5f221131dd80bf4b211a0e12345675a Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sat, 8 May 2021 18:56:26 +0200 Subject: [PATCH 13/19] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 5 ++++- tests/base/model_train_dataloaders.py | 2 -- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8b54e16d51a0..f554a88b9dfd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- Fixed Parsing of Multiple Trainloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) +- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) + +- Fixed recursive passing of `wrong_type` keyword argument in `pytorch_lightning.utilities.apply_to_collection` ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) + - Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362)) diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py index 4869f300231b3..0eb4b257325ff 100644 --- a/tests/base/model_train_dataloaders.py +++ b/tests/base/model_train_dataloaders.py @@ -50,8 +50,6 @@ def train_dataloader__multiple_mapping(self): self.dataloader(num_samples=50, train=True), self.dataloader(num_samples=50, train=True) ] - - # Dict[str, List[DataLoader]] loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e} return loaders From a8d11d1f006da7c29dc4d60cbef1558fc08ee375 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 May 2021 16:57:05 +0000 Subject: [PATCH 14/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/supporters.py | 5 +++-- pytorch_lightning/utilities/apply_func.py | 13 ++++++++++--- tests/base/model_train_dataloaders.py | 5 +---- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 51de8be1d30b1..60fa74a539c3d 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -381,7 +381,6 @@ def sampler(self) -> Union[Iterable, Sequence, Mapping]: """Return a collections of samplers extracting from loaders.""" return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, 'sampler', None) - def _wrap_loaders_max_size_cycle(self) -> Any: """ Wraps all loaders to make sure they are cycled until the longest loader is exhausted @@ -396,7 +395,9 @@ def _wrap_loaders_max_size_cycle(self) -> Any: # multiple loaders if isinstance(self.loaders, (Sequence, Mapping)): - self.loaders = apply_to_collection(self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping)) + self.loaders = apply_to_collection( + self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping) + ) def __iter__(self) -> Any: """ diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 
1b9996f3166c4..1cbab2fb8dee9 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -85,13 +85,20 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - return elem_type({k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for k, v in data.items()}) + return elem_type({ + k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for k, v in data.items() + }) if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple - return elem_type(*(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data)) + return elem_type( + *(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data) + ) if isinstance(data, Sequence) and not isinstance(data, str): - return elem_type([apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data]) + return elem_type([ + apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data + ]) # data is neither of dtype, nor a collection return data diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py index 0eb4b257325ff..0ed38a4d8d011 100644 --- a/tests/base/model_train_dataloaders.py +++ b/tests/base/model_train_dataloaders.py @@ -41,10 +41,7 @@ def train_dataloader__multiple_mapping(self): """Return a mapping loaders with different lengths""" # List[DataLoader] - loaders_a_b = [ - self.dataloader(num_samples=100, train=True), - self.dataloader(num_samples=50, train=True) - ] + loaders_a_b = [self.dataloader(num_samples=100, train=True), self.dataloader(num_samples=50, train=True)] loaders_c_d_e = [ self.dataloader(num_samples=50, train=True), self.dataloader(num_samples=50, train=True), From 27f839453b8bf9cdb5e5cd1ea9dc825fb8c7cad1 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 10 May 2021 10:33:37 +0200 Subject: [PATCH 15/19] Update tests/trainer/test_dataloaders.py Co-authored-by: Akihiro Nitta --- tests/trainer/test_dataloaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 7484fe86c23e5..5cfb10fb2f31e 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -26,7 +26,6 @@ import tests.helpers.pipelines as tpipes from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException From 614dd07afabd2a4d804afd4f7af985334b66576d Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 10 May 2021 10:40:45 +0200 Subject: [PATCH 16/19] Apply suggestions from code review Co-authored-by: Edgar Riba --- pytorch_lightning/trainer/supporters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 60fa74a539c3d..83d263eb036a0 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -291,7 +291,7 @@ def _calc_num_data(self, datasets: 
Union[Sequence, Mapping], mode: str) -> Union return length - def _get_len_recursive(self, data): + def _get_len_recursive(self, data) -> int: if isinstance(data, Dataset): return len(data) @@ -313,7 +313,7 @@ def _get_len_recursive(self, data): return self._get_len(data) @staticmethod - def _get_len(dataset): + def _get_len(dataset) -> int: try: return len(dataset) except (TypeError, NotImplementedError): From e5ef3734362c8dc216752a1819ded1de68b41838 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 10 May 2021 11:29:55 +0200 Subject: [PATCH 17/19] Update supporters.py --- pytorch_lightning/trainer/supporters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 83d263eb036a0..e53274ad8111e 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -299,13 +299,13 @@ def _get_len_recursive(self, data) -> int: return data elif isinstance(data, Mapping): - if isinstance(list(data.values())[0], (Mapping, Sequence, Dataset, Iterable)): + if iany(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()): return {k: self._get_len_recursive(v) for k, v in data.items()} else: return self._get_len(data) elif isinstance(data, Sequence): data = list(data) - if isinstance(data[0], (Mapping, Sequence, Dataset, Iterable)): + if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data): return [self._get_len_recursive(v) for v in data] else: return self._get_len(data) From 00c5f91edb834e182af68cac7ad4a03389fd7f64 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 10 May 2021 11:30:17 +0200 Subject: [PATCH 18/19] Update supporters.py --- pytorch_lightning/trainer/supporters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index e53274ad8111e..c438b20be6cef 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -299,7 +299,7 @@ def _get_len_recursive(self, data) -> int: return data elif isinstance(data, Mapping): - if iany(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()): + if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()): return {k: self._get_len_recursive(v) for k, v in data.items()} else: return self._get_len(data) From 6295aedae40645566fc810f9e6063a2fe755a118 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 10 May 2021 13:28:28 +0200 Subject: [PATCH 19/19] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/trainer/supporters.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index c438b20be6cef..df6db1e180c24 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -301,14 +301,10 @@ def _get_len_recursive(self, data) -> int: elif isinstance(data, Mapping): if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()): return {k: self._get_len_recursive(v) for k, v in data.items()} - else: - return self._get_len(data) elif isinstance(data, Sequence): data = list(data) if 
any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data): return [self._get_len_recursive(v) for v in data] - else: - return self._get_len(data) return self._get_len(data)
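
Taken together, the series lets a LightningModule return arbitrarily nested collections of training dataloaders (for example Dict[str, List[DataLoader]]) and have CombinedLoader cycle the shorter ones in 'max_size_cycle' mode. The sketch below is illustrative only and is not part of any patch above; it assumes the CombinedLoader constructor, the 'max_size_cycle' mode, and the nested batch structure shown in pytorch_lightning/trainer/supporters.py and tests/base in the diffs.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning.trainer.supporters import CombinedLoader


    def make_loader(num_samples: int, batch_size: int = 10) -> DataLoader:
        # Small dummy loader; num_samples controls how many batches it yields.
        data = TensorDataset(torch.randn(num_samples, 4), torch.zeros(num_samples))
        return DataLoader(data, batch_size=batch_size)


    # Nested collection of train loaders, the shape exercised by
    # train_dataloader__multiple_mapping in the tests above.
    loaders = {
        "a_b": [make_loader(100), make_loader(50)],
        "c_d_e": [make_loader(50), make_loader(50), make_loader(50)],
    }

    combined = CombinedLoader(loaders, "max_size_cycle")

    num_batches = 0
    for batch in combined:
        # Each batch mirrors the input structure: a dict of lists of (x, y) pairs.
        x, y = batch["a_b"][0]
        assert isinstance(batch["c_d_e"], list) and len(batch["c_d_e"]) == 3
        num_batches += 1

    # In 'max_size_cycle' mode the shorter loaders are wrapped in CycleIterator,
    # so iteration runs for as many batches as the longest loader (here 10).
    print(num_batches)

The wrong_dtype propagation added to apply_to_collection in PATCH 02/19 is what makes the nested case work: before that change, the recursive calls dropped wrong_dtype, so an inner list inside the mapping matched the Iterable dtype itself and get_len / CycleIterator were applied to the list rather than to the individual DataLoaders it contains.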