diff --git a/CHANGELOG.md b/CHANGELOG.md index 197c9fa4ff864..d3c8f5c4e1672 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) + +- Fixed recursive passing of `wrong_type` keyword argument in `pytorch_lightning.utilities.apply_to_collection` ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) ## [1.3.1] - 2021-05-11 diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 18a012da54760..df6db1e180c24 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -262,8 +262,7 @@ def max_len(self) -> Union[int, float]: def min_len(self) -> Union[int, float]: return self._calc_num_data(self.datasets, 'min_size') - @staticmethod - def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: + def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: """ Compute the length of `CombinedDataset` according to the `mode`. @@ -281,9 +280,7 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, raise MisconfigurationException(f"Invalid Mode: {mode}") # extract the lengths - all_lengths = apply_to_collection( - datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping) - ) + all_lengths = self._get_len_recursive(datasets) compute_func = CombinedDataset.COMPUTE_FUNCS[mode] @@ -294,6 +291,30 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, return length + def _get_len_recursive(self, data) -> int: + if isinstance(data, Dataset): + return len(data) + + elif isinstance(data, (float, int)): + return data + + elif isinstance(data, Mapping): + if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()): + return {k: self._get_len_recursive(v) for k, v in data.items()} + elif isinstance(data, Sequence): + data = list(data) + if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data): + return [self._get_len_recursive(v) for v in data] + + return self._get_len(data) + + @staticmethod + def _get_len(dataset) -> int: + try: + return len(dataset) + except (TypeError, NotImplementedError): + return float('inf') + def __len__(self) -> int: """Return the minimum length of the datasets.""" return self._calc_num_data(self.datasets, self.mode) @@ -335,6 +356,9 @@ def __init__(self, loaders: Any, mode: str = 'min_size'): 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones. """ + if mode not in self.SUPPORTED_MODES: + raise MisconfigurationException(f"Invalid Mode: {mode}") + self.loaders = loaders datasets = apply_to_collection( @@ -343,9 +367,6 @@ def __init__(self, loaders: Any, mode: str = 'min_size'): # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader self.dataset = CombinedDataset(datasets, mode) - if mode not in self.SUPPORTED_MODES: - raise MisconfigurationException(f"Invalid Mode: {mode}") - self.mode = mode if self.mode == 'max_size_cycle': @@ -366,27 +387,13 @@ def _wrap_loaders_max_size_cycle(self) -> Any: """ all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) - if isinstance(all_lengths, (int, float)): - length = all_lengths - - elif isinstance(all_lengths, Mapping): - length = max(all_lengths.values()) + length = _nested_calc_num_data(all_lengths, max) - elif isinstance(all_lengths, Sequence): - length = max(all_lengths) - - if isinstance(self.loaders, Mapping): - self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()}) - - elif isinstance(self.loaders, Sequence): - self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders]) - - # dataloaders are iterable but not sequence - elif isinstance(self.loaders, Iterable): - # only one dataloader, just keep it the same. - pass - else: - raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}') + # multiple loaders + if isinstance(self.loaders, (Sequence, Mapping)): + self.loaders = apply_to_collection( + self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping) + ) def __iter__(self) -> Any: """ @@ -490,7 +497,7 @@ def create_loader_iters( def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable): - if isinstance(data, int): + if isinstance(data, (float, int)): return data if isinstance(data, Mapping): diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index e100a803bcd00..1cbab2fb8dee9 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -85,13 +85,20 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) + return elem_type({ + k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for k, v in data.items() + }) if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple - return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) + return elem_type( + *(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data) + ) if isinstance(data, Sequence) and not isinstance(data, str): - return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) + return elem_type([ + apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data + ]) # data is neither of dtype, nor a collection return data diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py index 50c85ddc3f79d..0ed38a4d8d011 100644 --- a/tests/base/model_train_dataloaders.py +++ b/tests/base/model_train_dataloaders.py @@ -39,7 +39,14 @@ def train_dataloader__zero_length(self): def train_dataloader__multiple_mapping(self): """Return a mapping loaders with different lengths""" - return { - 'a': self.dataloader(train=True, num_samples=100), - 'b': self.dataloader(train=True, num_samples=50), - } + + # List[DataLoader] + loaders_a_b = [self.dataloader(num_samples=100, train=True), self.dataloader(num_samples=50, train=True)] + loaders_c_d_e = [ + self.dataloader(num_samples=50, train=True), + self.dataloader(num_samples=50, train=True), + self.dataloader(num_samples=50, train=True) + ] + # Dict[str, List[DataLoader]] + loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e} + return loaders diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 9e26e9fc93bae..c24cf5ded575a 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -49,10 +49,13 @@ def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=No assert isinstance(batch, dict) assert len(batch) == 2 - assert 'a' in batch and 'b' in batch + + assert 'a_b' in batch and 'c_d_e' in batch, batch.keys() + assert isinstance(batch['a_b'], list) and len(batch['a_b']) == 2 + assert isinstance(batch['c_d_e'], list) and len(batch['c_d_e']) == 3 # forward pass - x, y = batch['a'] + x, y = batch['a_b'][0] x = x.view(x.size(0), -1) y_hat = self(x) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 6f78f125754b5..6caf7ee132300 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -766,7 +766,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): train_dl = model.dataloader(train=False) train_dl.num_workers = 0 - train_multi_dl = {'a': train_dl, 'b': train_dl} + train_multi_dl = {'a_b': [train_dl, train_dl], 'c_d_e': [train_dl, train_dl, train_dl]} val_multi_dl = [val_dl, val_dl] test_multi_dl = [train_dl, train_dl] diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 6d6b1e9ad1bdf..169c8cb80b04d 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -123,8 +123,9 @@ def test_combined_dataset(dataset_1, dataset_2): def test_combined_dataset_length_mode_error(): + dset = CombinedDataset([range(10)]) with pytest.raises(MisconfigurationException, match="Invalid Mode"): - CombinedDataset._calc_num_data([range(10)], "test") + dset._calc_num_data([range(10)], "test") def test_combined_loader_iterator_dict_min_size(): @@ -146,13 +147,13 @@ def test_combined_loader_iterator_dict_min_size(): def test_combined_loader_init_mode_error(): """Test the ValueError when constructing `CombinedLoader`""" - with pytest.raises(MisconfigurationException, match="selected unsupported mode"): + with pytest.raises(MisconfigurationException, match="Invalid Mode"): CombinedLoader([range(10)], "testtt") def test_combined_loader_loader_type_error(): """Test the ValueError when wrapping the loaders""" - with pytest.raises(ValueError, match="Invalid Datatype"): + with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"): CombinedLoader(None, "max_size_cycle")