Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))

- Fixed recursive passing of the `wrong_dtype` keyword argument in `pytorch_lightning.utilities.apply_to_collection` ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433))


## [1.3.1] - 2021-05-11
Expand Down
65 changes: 36 additions & 29 deletions pytorch_lightning/trainer/supporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,7 @@ def max_len(self) -> Union[int, float]:
def min_len(self) -> Union[int, float]:
return self._calc_num_data(self.datasets, 'min_size')

@staticmethod
def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
"""
Compute the length of `CombinedDataset` according to the `mode`.

Expand All @@ -281,9 +280,7 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,
raise MisconfigurationException(f"Invalid Mode: {mode}")

# extract the lengths
all_lengths = apply_to_collection(
datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping)
)
all_lengths = self._get_len_recursive(datasets)

compute_func = CombinedDataset.COMPUTE_FUNCS[mode]

Expand All @@ -294,6 +291,30 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,

return length

def _get_len_recursive(self, data) -> int:
if isinstance(data, Dataset):
return len(data)

elif isinstance(data, (float, int)):
return data

elif isinstance(data, Mapping):
if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()):
return {k: self._get_len_recursive(v) for k, v in data.items()}
elif isinstance(data, Sequence):
data = list(data)
if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data):
return [self._get_len_recursive(v) for v in data]

return self._get_len(data)

@staticmethod
def _get_len(dataset) -> int:
try:
return len(dataset)
except (TypeError, NotImplementedError):
return float('inf')

def __len__(self) -> int:
"""Return the minimum length of the datasets."""
return self._calc_num_data(self.datasets, self.mode)
Expand Down Expand Up @@ -335,6 +356,9 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.

"""
if mode not in self.SUPPORTED_MODES:
raise MisconfigurationException(f"Invalid Mode: {mode}")

self.loaders = loaders

datasets = apply_to_collection(
Expand All @@ -343,9 +367,6 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
# could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
self.dataset = CombinedDataset(datasets, mode)

if mode not in self.SUPPORTED_MODES:
raise MisconfigurationException(f"Invalid Mode: {mode}")

self.mode = mode

if self.mode == 'max_size_cycle':
Expand All @@ -366,27 +387,13 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
"""
all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))

if isinstance(all_lengths, (int, float)):
length = all_lengths

elif isinstance(all_lengths, Mapping):
length = max(all_lengths.values())
length = _nested_calc_num_data(all_lengths, max)

elif isinstance(all_lengths, Sequence):
length = max(all_lengths)

if isinstance(self.loaders, Mapping):
self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()})

elif isinstance(self.loaders, Sequence):
self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders])

# dataloaders are iterable but not sequence
elif isinstance(self.loaders, Iterable):
# only one dataloader, just keep it the same.
pass
else:
raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}')
# multiple loaders
if isinstance(self.loaders, (Sequence, Mapping)):
self.loaders = apply_to_collection(
self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping)
)

def __iter__(self) -> Any:
"""
Expand Down Expand Up @@ -490,7 +497,7 @@ def create_loader_iters(

def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):

if isinstance(data, int):
if isinstance(data, (float, int)):
return data

if isinstance(data, Mapping):
Expand Down
13 changes: 10 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,20 @@ def apply_to_collection(

# Recursively apply to collection items
if isinstance(data, Mapping):
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
return elem_type({
k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
for k, v in data.items()
})

if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
return elem_type(
*(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data)
)

if isinstance(data, Sequence) and not isinstance(data, str):
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
return elem_type([
apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data
])

# data is neither of dtype, nor a collection
return data
Expand Down
15 changes: 11 additions & 4 deletions tests/base/model_train_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ def train_dataloader__zero_length(self):

def train_dataloader__multiple_mapping(self):
"""Return a mapping loaders with different lengths"""
return {
'a': self.dataloader(train=True, num_samples=100),
'b': self.dataloader(train=True, num_samples=50),
}

# List[DataLoader]
loaders_a_b = [self.dataloader(num_samples=100, train=True), self.dataloader(num_samples=50, train=True)]
loaders_c_d_e = [
self.dataloader(num_samples=50, train=True),
self.dataloader(num_samples=50, train=True),
self.dataloader(num_samples=50, train=True)
]
# Dict[str, List[DataLoader]]
loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e}
return loaders
7 changes: 5 additions & 2 deletions tests/base/model_train_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@ def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=No

assert isinstance(batch, dict)
assert len(batch) == 2
assert 'a' in batch and 'b' in batch

assert 'a_b' in batch and 'c_d_e' in batch, batch.keys()
assert isinstance(batch['a_b'], list) and len(batch['a_b']) == 2
assert isinstance(batch['c_d_e'], list) and len(batch['c_d_e']) == 3

# forward pass
x, y = batch['a']
x, y = batch['a_b'][0]
x = x.view(x.size(0), -1)
y_hat = self(x)

Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
train_dl = model.dataloader(train=False)
train_dl.num_workers = 0

train_multi_dl = {'a': train_dl, 'b': train_dl}
train_multi_dl = {'a_b': [train_dl, train_dl], 'c_d_e': [train_dl, train_dl, train_dl]}
val_multi_dl = [val_dl, val_dl]
test_multi_dl = [train_dl, train_dl]

Expand Down
7 changes: 4 additions & 3 deletions tests/trainer/test_supporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ def test_combined_dataset(dataset_1, dataset_2):


def test_combined_dataset_length_mode_error():
dset = CombinedDataset([range(10)])
with pytest.raises(MisconfigurationException, match="Invalid Mode"):
CombinedDataset._calc_num_data([range(10)], "test")
dset._calc_num_data([range(10)], "test")


def test_combined_loader_iterator_dict_min_size():
Expand All @@ -146,13 +147,13 @@ def test_combined_loader_iterator_dict_min_size():

def test_combined_loader_init_mode_error():
"""Test the ValueError when constructing `CombinedLoader`"""
with pytest.raises(MisconfigurationException, match="selected unsupported mode"):
with pytest.raises(MisconfigurationException, match="Invalid Mode"):
CombinedLoader([range(10)], "testtt")


def test_combined_loader_loader_type_error():
"""Test the ValueError when wrapping the loaders"""
with pytest.raises(ValueError, match="Invalid Datatype"):
with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
CombinedLoader(None, "max_size_cycle")


Expand Down