Skip to content

Commit ef7345d

Browse files
authored
add possibility for nested loaders (#5404)
* add possibility for nested loaders * pep8: newline
1 parent 6386f45 commit ef7345d

File tree

3 files changed

+94
-47
lines changed

3 files changed

+94
-47
lines changed

pytorch_lightning/loggers/wandb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
_WANDB_AVAILABLE = _module_available("wandb")
3030

3131
try:
32-
import wandb
3332
from wandb.wandb_run import Run
33+
34+
import wandb
3435
except ImportError:
3536
# needed for test mocks, these tests shall be updated
3637
wandb, Run = None, None

pytorch_lightning/trainer/supporters.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import os
1616
from collections.abc import Iterable, Iterator, Mapping, Sequence
17-
from typing import Any, Optional, Union
17+
from typing import Any, Callable, Optional, Union
1818

1919
import torch
2020
from torch import Tensor
@@ -306,12 +306,8 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,
306306

307307
if isinstance(all_lengths, (int, float)):
308308
length = all_lengths
309-
310-
elif isinstance(all_lengths, Mapping):
311-
length = compute_func(all_lengths.values())
312-
313-
elif isinstance(all_lengths, Sequence):
314-
length = compute_func(all_lengths)
309+
else:
310+
length = _nested_calc_num_data(all_lengths, compute_func)
315311

316312
return length
317313

@@ -437,13 +433,8 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
437433
if isinstance(all_lengths, (int, float)):
438434
return all_lengths
439435

440-
elif isinstance(all_lengths, Mapping):
441-
return min(all_lengths.values())
442-
443-
elif isinstance(all_lengths, Sequence):
444-
return min(all_lengths)
445-
446-
raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, int or Mapping')
436+
else:
437+
return _nested_calc_num_data(all_lengths, min)
447438

448439
def __len__(self) -> int:
449440
return self._calc_num_batches(self.loaders)
@@ -516,3 +507,25 @@ def create_loader_iters(
516507
"""
517508
# dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences
518509
return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))
510+
511+
512+
def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):
513+
514+
if isinstance(data, int):
515+
return data
516+
517+
if isinstance(data, Mapping):
518+
data = list(data.values())
519+
520+
if not isinstance(data, Sequence):
521+
raise TypeError(f'Expected data to be int, Sequence or Mapping, but got {type(data).__name__}')
522+
523+
new_data = []
524+
525+
for x in data:
526+
if isinstance(x, (Mapping, Sequence)):
527+
new_data.append(_nested_calc_num_data(x, compute_func))
528+
else:
529+
new_data.append(x)
530+
531+
return compute_func(new_data)

tests/trainer/test_supporters.py

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch.utils.data import TensorDataset
1919

2020
from pytorch_lightning.trainer.supporters import (
21+
_nested_calc_num_data,
2122
CombinedDataset,
2223
CombinedLoader,
2324
CombinedLoaderIterator,
@@ -61,7 +62,7 @@ def test_cycle_iterator():
6162
def test_none_length_cycle_iterator():
6263
"""Test the infinite cycling function of `CycleIterator`"""
6364
iterator = CycleIterator(range(100))
64-
assert iterator.__len__() == float('inf')
65+
assert iterator.__len__() == float("inf")
6566

6667
# test infinite loop
6768
for idx, item in enumerate(iterator):
@@ -70,12 +71,15 @@ def test_none_length_cycle_iterator():
7071
assert item == 0
7172

7273

73-
@pytest.mark.parametrize(['dataset_1', 'dataset_2'], [
74-
([list(range(10)), list(range(20))]),
75-
([range(10), range(20)]),
76-
([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
77-
([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))])
78-
])
74+
@pytest.mark.parametrize(
75+
["dataset_1", "dataset_2"],
76+
[
77+
([list(range(10)), list(range(20))]),
78+
([range(10), range(20)]),
79+
([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
80+
([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]),
81+
],
82+
)
7983
def test_combined_dataset(dataset_1, dataset_2):
8084
"""Verify the length of the CombinedDataset"""
8185
datasets = [dataset_1, dataset_2]
@@ -86,83 +90,91 @@ def test_combined_dataset(dataset_1, dataset_2):
8690

8791

8892
def test_combined_dataset_length_mode_error():
89-
with pytest.raises(MisconfigurationException, match='Invalid Mode'):
90-
CombinedDataset._calc_num_data([range(10)], 'test')
93+
with pytest.raises(MisconfigurationException, match="Invalid Mode"):
94+
CombinedDataset._calc_num_data([range(10)], "test")
9195

9296

9397
def test_combined_loader_iterator_dict_min_size():
9498
"""Test `CombinedLoaderIterator` given mapping loaders"""
95-
loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
96-
'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
99+
loaders = {
100+
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
101+
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
102+
}
97103

98104
combined_iter = CombinedLoaderIterator(loaders)
99105

100106
for idx, item in enumerate(combined_iter):
101107
assert isinstance(item, dict)
102108
assert len(item) == 2
103-
assert 'a' in item and 'b' in item
109+
assert "a" in item and "b" in item
104110

105-
assert idx == min(len(loaders['a']), len(loaders['b'])) - 1
111+
assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1
106112

107113

108114
def test_combined_loader_init_mode_error():
109115
"""Test the MisconfigurationException raised when constructing `CombinedLoader` with an unsupported mode"""
110-
with pytest.raises(MisconfigurationException, match='selected unsupported mode'):
111-
CombinedLoader([range(10)], 'testtt')
116+
with pytest.raises(MisconfigurationException, match="selected unsupported mode"):
117+
CombinedLoader([range(10)], "testtt")
112118

113119

114120
def test_combined_loader_loader_type_error():
115121
"""Test the ValueError when wrapping the loaders"""
116-
with pytest.raises(ValueError, match='Invalid Datatype'):
117-
CombinedLoader(None, 'max_size_cycle')
122+
with pytest.raises(ValueError, match="Invalid Datatype"):
123+
CombinedLoader(None, "max_size_cycle")
118124

119125

120126
def test_combined_loader_calc_length_mode_error():
121127
"""Test the TypeError raised when calculating the number of batches"""
122-
with pytest.raises(TypeError, match='Got Type NoneType, but expected one of Sequence, int or Mapping'):
128+
with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
123129
CombinedLoader._calc_num_batches(None)
124130

125131

126132
def test_combined_loader_dict_min_size():
127133
"""Test `CombinedLoader` of mode 'min_size' given mapping loaders"""
128-
loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
129-
'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
134+
loaders = {
135+
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
136+
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
137+
}
130138

131-
combined_loader = CombinedLoader(loaders, 'min_size')
139+
combined_loader = CombinedLoader(loaders, "min_size")
132140

133141
assert len(combined_loader) == min([len(v) for v in loaders.values()])
134142

135143
for idx, item in enumerate(combined_loader):
136144
assert isinstance(item, dict)
137145
assert len(item) == 2
138-
assert 'a' in item and 'b' in item
146+
assert "a" in item and "b" in item
139147

140148
assert idx == len(combined_loader) - 1
141149

142150

143151
def test_combined_loader_dict_max_size_cycle():
144152
"""Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders"""
145-
loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
146-
'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
153+
loaders = {
154+
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
155+
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
156+
}
147157

148-
combined_loader = CombinedLoader(loaders, 'max_size_cycle')
158+
combined_loader = CombinedLoader(loaders, "max_size_cycle")
149159

150160
assert len(combined_loader) == max([len(v) for v in loaders.values()])
151161

152162
for idx, item in enumerate(combined_loader):
153163
assert isinstance(item, dict)
154164
assert len(item) == 2
155-
assert 'a' in item and 'b' in item
165+
assert "a" in item and "b" in item
156166

157167
assert idx == len(combined_loader) - 1
158168

159169

160170
def test_combined_loader_sequence_min_size():
161171
"""Test `CombinedLoader` of mode 'min_size' given sequence loaders"""
162-
loaders = [torch.utils.data.DataLoader(range(10), batch_size=4),
163-
torch.utils.data.DataLoader(range(20), batch_size=5)]
172+
loaders = [
173+
torch.utils.data.DataLoader(range(10), batch_size=4),
174+
torch.utils.data.DataLoader(range(20), batch_size=5),
175+
]
164176

165-
combined_loader = CombinedLoader(loaders, 'min_size')
177+
combined_loader = CombinedLoader(loaders, "min_size")
166178

167179
assert len(combined_loader) == min([len(v) for v in loaders])
168180

@@ -175,10 +187,12 @@ def test_combined_loader_sequence_min_size():
175187

176188
def test_combined_loader_sequence_max_size_cycle():
177189
"""Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders"""
178-
loaders = [torch.utils.data.DataLoader(range(10), batch_size=4),
179-
torch.utils.data.DataLoader(range(20), batch_size=5)]
190+
loaders = [
191+
torch.utils.data.DataLoader(range(10), batch_size=4),
192+
torch.utils.data.DataLoader(range(20), batch_size=5),
193+
]
180194

181-
combined_loader = CombinedLoader(loaders, 'max_size_cycle')
195+
combined_loader = CombinedLoader(loaders, "max_size_cycle")
182196

183197
assert len(combined_loader) == max([len(v) for v in loaders])
184198

@@ -187,3 +201,22 @@ def test_combined_loader_sequence_max_size_cycle():
187201
assert len(item) == 2
188202

189203
assert idx == len(combined_loader) - 1
204+
205+
206+
@pytest.mark.parametrize(
    ["input_data", "compute_func", "expected_length"],
    [
        ([*range(10), list(range(1, 20))], min, 0),
        ([*range(10), list(range(1, 20))], max, 19),
        ([*range(10), {str(i): i for i in range(1, 20)}], min, 0),
        ([*range(10), {str(i): i for i in range(1, 20)}], max, 19),
        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, min, 0),
        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, max, 19),
        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, min, 0),
        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, max, 19),
    ],
)
def test_nested_calc_num_data(input_data, compute_func, expected_length):
    """Verify `_nested_calc_num_data` reduces nested lists/dicts of lengths with ``compute_func``"""
    # Each case mixes flat entries with one nested list or dict.
    assert _nested_calc_num_data(input_data, compute_func) == expected_length

0 commit comments

Comments
 (0)