
Commit c627040

justusschock, awaelchli, pre-commit-ci[bot], akihironitta, and edgarriba authored and committed
Bugfix/Multiple dataloaders (#7433)
* Update supporters.py
* Update apply_func.py
* Update supporters.py
* Update model_train_dataloaders.py
* Update model_train_steps.py
* Update test_dataloaders.py
* Update CHANGELOG.md
* Update model_train_steps.py
* Update test_dataloaders.py
* Update test_dataloaders.py
* Update supporters.py
* Update test_supporters.py
* Apply suggestions from code review
  Co-authored-by: Adrian Wälchli <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update tests/trainer/test_dataloaders.py
  Co-authored-by: Akihiro Nitta <[email protected]>
* Apply suggestions from code review
  Co-authored-by: Edgar Riba <[email protected]>
* Update supporters.py
* Update supporters.py
* Apply suggestions from code review
  Co-authored-by: Carlos Mocholí <[email protected]>

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Edgar Riba <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 80992e6 commit c627040

File tree

7 files changed (+70 / -42 lines changed)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
### Fixed
 
+- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))
+
+- Fixed recursive passing of `wrong_dtype` keyword argument in `pytorch_lightning.utilities.apply_to_collection` ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))
 
- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))
 

pytorch_lightning/trainer/supporters.py

Lines changed: 36 additions & 29 deletions
@@ -262,8 +262,7 @@ def max_len(self) -> Union[int, float]:
    def min_len(self) -> Union[int, float]:
        return self._calc_num_data(self.datasets, 'min_size')

-    @staticmethod
-    def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
+    def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
        """
        Compute the length of `CombinedDataset` according to the `mode`.

@@ -281,9 +280,7 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,
            raise MisconfigurationException(f"Invalid Mode: {mode}")

        # extract the lengths
-        all_lengths = apply_to_collection(
-            datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping)
-        )
+        all_lengths = self._get_len_recursive(datasets)

        compute_func = CombinedDataset.COMPUTE_FUNCS[mode]

@@ -294,6 +291,30 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,

        return length

+    def _get_len_recursive(self, data) -> int:
+        if isinstance(data, Dataset):
+            return len(data)
+
+        elif isinstance(data, (float, int)):
+            return data
+
+        elif isinstance(data, Mapping):
+            if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()):
+                return {k: self._get_len_recursive(v) for k, v in data.items()}
+        elif isinstance(data, Sequence):
+            data = list(data)
+            if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data):
+                return [self._get_len_recursive(v) for v in data]
+
+        return self._get_len(data)
+
+    @staticmethod
+    def _get_len(dataset) -> int:
+        try:
+            return len(dataset)
+        except (TypeError, NotImplementedError):
+            return float('inf')
+
    def __len__(self) -> int:
        """Return the minimum length of the datasets."""
        return self._calc_num_data(self.datasets, self.mode)

@@ -335,6 +356,9 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
                'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.

        """
+        if mode not in self.SUPPORTED_MODES:
+            raise MisconfigurationException(f"Invalid Mode: {mode}")
+
        self.loaders = loaders

        datasets = apply_to_collection(

@@ -343,9 +367,6 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
        # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
        self.dataset = CombinedDataset(datasets, mode)

-        if mode not in self.SUPPORTED_MODES:
-            raise MisconfigurationException(f"Invalid Mode: {mode}")
-
        self.mode = mode

        if self.mode == 'max_size_cycle':

@@ -366,27 +387,13 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
        """
        all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))

-        if isinstance(all_lengths, (int, float)):
-            length = all_lengths
-
-        elif isinstance(all_lengths, Mapping):
-            length = max(all_lengths.values())
+        length = _nested_calc_num_data(all_lengths, max)

-        elif isinstance(all_lengths, Sequence):
-            length = max(all_lengths)
-
-        if isinstance(self.loaders, Mapping):
-            self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()})
-
-        elif isinstance(self.loaders, Sequence):
-            self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders])
-
-        # dataloaders are iterable but not sequence
-        elif isinstance(self.loaders, Iterable):
-            # only one dataloader, just keep it the same.
-            pass
-        else:
-            raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}')
+        # multiple loaders
+        if isinstance(self.loaders, (Sequence, Mapping)):
+            self.loaders = apply_to_collection(
+                self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping)
+            )

    def __iter__(self) -> Any:
        """

@@ -490,7 +497,7 @@ def create_loader_iters(

def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):

-    if isinstance(data, int):
+    if isinstance(data, (float, int)):
        return data

    if isinstance(data, Mapping):
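Note: to make the effect of the new `_get_len_recursive` concrete, here is a minimal sketch (an illustration, not part of the commit; the keys and dataset sizes are invented to mirror the tests) of how `CombinedDataset` can now report a length for a nested `Dict[str, List[Dataset]]`:

import torch
from torch.utils.data import TensorDataset

from pytorch_lightning.trainer.supporters import CombinedDataset

# nested Dict[str, List[Dataset]] with lengths 100/50 and 50/50/50
datasets = {
    "a_b": [TensorDataset(torch.randn(100, 2)), TensorDataset(torch.randn(50, 2))],
    "c_d_e": [TensorDataset(torch.randn(50, 2)) for _ in range(3)],
}

combined = CombinedDataset(datasets, mode="min_size")
# _get_len_recursive yields {"a_b": [100, 50], "c_d_e": [50, 50, 50]};
# 'min_size' then reduces the nested lengths to the smallest one.
print(len(combined))  # expected: 50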

pytorch_lightning/utilities/apply_func.py

Lines changed: 10 additions & 3 deletions
@@ -85,13 +85,20 @@ def apply_to_collection(

    # Recursively apply to collection items
    if isinstance(data, Mapping):
-        return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
+        return elem_type({
+            k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
+            for k, v in data.items()
+        })

    if isinstance(data, tuple) and hasattr(data, '_fields'):  # named tuple
-        return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
+        return elem_type(
+            *(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data)
+        )

    if isinstance(data, Sequence) and not isinstance(data, str):
-        return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
+        return elem_type([
+            apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data
+        ])

    # data is neither of dtype, nor a collection
    return data
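Note: the change above forwards `wrong_dtype` on every recursive call of `apply_to_collection`. A rough sketch of the difference (an illustration, not from the commit; the loaders and batch sizes are invented):

from collections.abc import Iterable, Mapping, Sequence

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection

loaders = {
    "a_b": [
        DataLoader(TensorDataset(torch.randn(100, 2)), batch_size=10),
        DataLoader(TensorDataset(torch.randn(50, 2)), batch_size=10),
    ]
}

# wrong_dtype=(Sequence, Mapping) tells apply_to_collection to recurse into
# dicts/lists instead of treating them as `Iterable` leaves.
lengths = apply_to_collection(loaders, Iterable, len, wrong_dtype=(Sequence, Mapping))
print(lengths)  # with this fix: {'a_b': [10, 5]}  (number of batches per loader)
# Before the fix, wrong_dtype was dropped after the first recursion, so the inner
# list itself matched `Iterable` and was measured wholesale: {'a_b': 2}.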

tests/base/model_train_dataloaders.py

Lines changed: 11 additions & 4 deletions
@@ -39,7 +39,14 @@ def train_dataloader__zero_length(self):

    def train_dataloader__multiple_mapping(self):
        """Return a mapping loaders with different lengths"""
-        return {
-            'a': self.dataloader(train=True, num_samples=100),
-            'b': self.dataloader(train=True, num_samples=50),
-        }
+
+        # List[DataLoader]
+        loaders_a_b = [self.dataloader(num_samples=100, train=True), self.dataloader(num_samples=50, train=True)]
+        loaders_c_d_e = [
+            self.dataloader(num_samples=50, train=True),
+            self.dataloader(num_samples=50, train=True),
+            self.dataloader(num_samples=50, train=True)
+        ]
+        # Dict[str, List[DataLoader]]
+        loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e}
+        return loaders

tests/base/model_train_steps.py

Lines changed: 5 additions & 2 deletions
@@ -63,10 +63,13 @@ def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=No

        assert isinstance(batch, dict)
        assert len(batch) == 2
-        assert 'a' in batch and 'b' in batch
+
+        assert 'a_b' in batch and 'c_d_e' in batch, batch.keys()
+        assert isinstance(batch['a_b'], list) and len(batch['a_b']) == 2
+        assert isinstance(batch['c_d_e'], list) and len(batch['c_d_e']) == 3

        # forward pass
-        x, y = batch['a']
+        x, y = batch['a_b'][0]
        x = x.view(x.size(0), -1)
        y_hat = self(x)

tests/trainer/test_dataloaders.py

Lines changed: 1 addition & 1 deletion
@@ -766,7 +766,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
    train_dl = model.dataloader(train=False)
    train_dl.num_workers = 0

-    train_multi_dl = {'a': train_dl, 'b': train_dl}
+    train_multi_dl = {'a_b': [train_dl, train_dl], 'c_d_e': [train_dl, train_dl, train_dl]}
    val_multi_dl = [val_dl, val_dl]
    test_multi_dl = [train_dl, train_dl]

tests/trainer/test_supporters.py

Lines changed: 4 additions & 3 deletions
@@ -123,8 +123,9 @@ def test_combined_dataset(dataset_1, dataset_2):


def test_combined_dataset_length_mode_error():
+    dset = CombinedDataset([range(10)])
    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
-        CombinedDataset._calc_num_data([range(10)], "test")
+        dset._calc_num_data([range(10)], "test")


def test_combined_loader_iterator_dict_min_size():

@@ -146,13 +147,13 @@ def test_combined_loader_iterator_dict_min_size():

def test_combined_loader_init_mode_error():
    """Test the ValueError when constructing `CombinedLoader`"""
-    with pytest.raises(MisconfigurationException, match="selected unsupported mode"):
+    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        CombinedLoader([range(10)], "testtt")


def test_combined_loader_loader_type_error():
    """Test the ValueError when wrapping the loaders"""
-    with pytest.raises(ValueError, match="Invalid Datatype"):
+    with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader(None, "max_size_cycle")
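Note: as a rough end-to-end illustration of the behaviour these tests pin down (my own sketch, not part of the test suite; keys and sizes are invented), `CombinedLoader` now accepts a `Dict[str, List[DataLoader]]` and yields batches that mirror that nesting:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.trainer.supporters import CombinedLoader

loaders = {
    "a_b": [
        DataLoader(TensorDataset(torch.randn(100, 2)), batch_size=10),
        DataLoader(TensorDataset(torch.randn(50, 2)), batch_size=10),
    ],
}

combined = CombinedLoader(loaders, mode="max_size_cycle")
batch = next(iter(combined))
# The batch mirrors the nesting of the input: a dict with key "a_b" mapping to a
# list of two per-loader batches (the shorter loader is cycled to the longest length).
assert isinstance(batch, dict) and len(batch["a_b"]) == 2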
