diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index 15b46a4d9bbd1..443c50ac4835f 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -29,8 +29,9 @@
 _WANDB_AVAILABLE = _module_available("wandb")
 
 try:
-    import wandb
     from wandb.wandb_run import Run
+
+    import wandb
 except ImportError:
     # needed for test mocks, these tests shall be updated
     wandb, Run = None, None
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 4c14f01640b1b..3cd29c5203032 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -14,7 +14,7 @@
 
 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import Tensor
@@ -306,12 +306,8 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int,
 
         if isinstance(all_lengths, (int, float)):
             length = all_lengths
-
-        elif isinstance(all_lengths, Mapping):
-            length = compute_func(all_lengths.values())
-
-        elif isinstance(all_lengths, Sequence):
-            length = compute_func(all_lengths)
+        else:
+            length = _nested_calc_num_data(all_lengths, compute_func)
 
         return length
 
@@ -437,13 +433,8 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
 
         if isinstance(all_lengths, (int, float)):
             return all_lengths
-        elif isinstance(all_lengths, Mapping):
-            return min(all_lengths.values())
-
-        elif isinstance(all_lengths, Sequence):
-            return min(all_lengths)
-
-        raise TypeError(f'Got Type {type(all_lengths).__name__}, but expected one of Sequence, int or Mapping')
+        else:
+            return _nested_calc_num_data(all_lengths, min)
 
     def __len__(self) -> int:
         return self._calc_num_batches(self.loaders)
@@ -516,3 +507,25 @@ def create_loader_iters(
         """
         # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences
         return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))
+
+
+def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):
+
+    if isinstance(data, int):
+        return data
+
+    if isinstance(data, Mapping):
+        data = list(data.values())
+
+    if not isinstance(data, Sequence):
+        raise TypeError(f'Expected data to be int, Sequence or Mapping, but got {type(data).__name__}')
+
+    new_data = []
+
+    for x in data:
+        if isinstance(x, (Mapping, Sequence)):
+            new_data.append(_nested_calc_num_data(x, compute_func))
+        else:
+            new_data.append(x)
+
+    return compute_func(new_data)
diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py
index f820e4a4f2ce2..0311a789c5782 100644
--- a/tests/trainer/test_supporters.py
+++ b/tests/trainer/test_supporters.py
@@ -18,6 +18,7 @@
 from torch.utils.data import TensorDataset
 
 from pytorch_lightning.trainer.supporters import (
+    _nested_calc_num_data,
     CombinedDataset,
     CombinedLoader,
     CombinedLoaderIterator,
@@ -61,7 +62,7 @@ def test_cycle_iterator():
 def test_none_length_cycle_iterator():
     """Test the infinite cycling function of `CycleIterator`"""
     iterator = CycleIterator(range(100))
-    assert iterator.__len__() == float('inf')
+    assert iterator.__len__() == float("inf")
 
     # test infinite loop
     for idx, item in enumerate(iterator):
@@ -70,12 +71,15 @@ def test_none_length_cycle_iterator():
         assert item == 0
 
 
-@pytest.mark.parametrize(['dataset_1', 'dataset_2'], [
-    ([list(range(10)), list(range(20))]),
-    ([range(10), range(20)]),
-    ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
-    ([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))])
-])
+@pytest.mark.parametrize(
+    ["dataset_1", "dataset_2"],
+    [
+        ([list(range(10)), list(range(20))]),
+        ([range(10), range(20)]),
+        ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
+        ([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]),
+    ],
+)
 def test_combined_dataset(dataset_1, dataset_2):
     """Verify the length of the CombinedDataset"""
     datasets = [dataset_1, dataset_2]
@@ -86,83 +90,91 @@ def test_combined_dataset(dataset_1, dataset_2):
 
 
 def test_combined_dataset_length_mode_error():
-    with pytest.raises(MisconfigurationException, match='Invalid Mode'):
-        CombinedDataset._calc_num_data([range(10)], 'test')
+    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
+        CombinedDataset._calc_num_data([range(10)], "test")
 
 
 def test_combined_loader_iterator_dict_min_size():
     """Test `CombinedLoaderIterator` given mapping loaders"""
-    loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
-               'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
+    loaders = {
+        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
+        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
+    }
 
     combined_iter = CombinedLoaderIterator(loaders)
 
     for idx, item in enumerate(combined_iter):
         assert isinstance(item, dict)
         assert len(item) == 2
-        assert 'a' in item and 'b' in item
+        assert "a" in item and "b" in item
 
-    assert idx == min(len(loaders['a']), len(loaders['b'])) - 1
+    assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1
 
 
 def test_combined_loader_init_mode_error():
     """Test the ValueError when constructing `CombinedLoader`"""
-    with pytest.raises(MisconfigurationException, match='selected unsupported mode'):
-        CombinedLoader([range(10)], 'testtt')
+    with pytest.raises(MisconfigurationException, match="selected unsupported mode"):
+        CombinedLoader([range(10)], "testtt")
 
 
 def test_combined_loader_loader_type_error():
     """Test the ValueError when wrapping the loaders"""
-    with pytest.raises(ValueError, match='Invalid Datatype'):
-        CombinedLoader(None, 'max_size_cycle')
+    with pytest.raises(ValueError, match="Invalid Datatype"):
+        CombinedLoader(None, "max_size_cycle")
 
 
 def test_combined_loader_calc_length_mode_error():
     """Test the ValueError when calculating the number of batches"""
-    with pytest.raises(TypeError, match='Got Type NoneType, but expected one of Sequence, int or Mapping'):
+    with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
         CombinedLoader._calc_num_batches(None)
 
 
 def test_combined_loader_dict_min_size():
     """Test `CombinedLoader` of mode 'min_size' given mapping loaders"""
-    loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
-               'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
+    loaders = {
+        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
+        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
+    }
 
-    combined_loader = CombinedLoader(loaders, 'min_size')
+    combined_loader = CombinedLoader(loaders, "min_size")
 
     assert len(combined_loader) == min([len(v) for v in loaders.values()])
 
     for idx, item in enumerate(combined_loader):
         assert isinstance(item, dict)
         assert len(item) == 2
-        assert 'a' in item and 'b' in item
+        assert "a" in item and "b" in item
 
     assert idx == len(combined_loader) - 1
 
 
 def test_combined_loader_dict_max_size_cycle():
     """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders"""
-    loaders = {'a': torch.utils.data.DataLoader(range(10), batch_size=4),
-               'b': torch.utils.data.DataLoader(range(20), batch_size=5)}
+    loaders = {
+        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
+        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
+    }
 
-    combined_loader = CombinedLoader(loaders, 'max_size_cycle')
+    combined_loader = CombinedLoader(loaders, "max_size_cycle")
 
     assert len(combined_loader) == max([len(v) for v in loaders.values()])
 
     for idx, item in enumerate(combined_loader):
         assert isinstance(item, dict)
         assert len(item) == 2
-        assert 'a' in item and 'b' in item
+        assert "a" in item and "b" in item
 
     assert idx == len(combined_loader) - 1
 
 
 def test_combined_loader_sequence_min_size():
     """Test `CombinedLoader` of mode 'min_size' given sequence loaders"""
-    loaders = [torch.utils.data.DataLoader(range(10), batch_size=4),
-               torch.utils.data.DataLoader(range(20), batch_size=5)]
+    loaders = [
+        torch.utils.data.DataLoader(range(10), batch_size=4),
+        torch.utils.data.DataLoader(range(20), batch_size=5),
+    ]
 
-    combined_loader = CombinedLoader(loaders, 'min_size')
+    combined_loader = CombinedLoader(loaders, "min_size")
 
     assert len(combined_loader) == min([len(v) for v in loaders])
 
@@ -175,10 +187,12 @@ def test_combined_loader_sequence_min_size():
 
 def test_combined_loader_sequence_max_size_cycle():
     """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders"""
-    loaders = [torch.utils.data.DataLoader(range(10), batch_size=4),
-               torch.utils.data.DataLoader(range(20), batch_size=5)]
+    loaders = [
+        torch.utils.data.DataLoader(range(10), batch_size=4),
+        torch.utils.data.DataLoader(range(20), batch_size=5),
+    ]
 
-    combined_loader = CombinedLoader(loaders, 'max_size_cycle')
+    combined_loader = CombinedLoader(loaders, "max_size_cycle")
 
     assert len(combined_loader) == max([len(v) for v in loaders])
 
@@ -187,3 +201,22 @@ def test_combined_loader_sequence_max_size_cycle():
         assert len(item) == 2
 
     assert idx == len(combined_loader) - 1
+
+
+@pytest.mark.parametrize(
+    ["input_data", "compute_func", "expected_length"],
+    [
+        ([*range(10), list(range(1, 20))], min, 0),
+        ([*range(10), list(range(1, 20))], max, 19),
+        ([*range(10), {str(i): i for i in range(1, 20)}], min, 0),
+        ([*range(10), {str(i): i for i in range(1, 20)}], max, 19),
+        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, min, 0),
+        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, max, 19),
+        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, min, 0),
+        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, max, 19),
+    ],
+)
+def test_nested_calc_num_data(input_data, compute_func, expected_length):
+    calculated_length = _nested_calc_num_data(input_data, compute_func)
+
+    assert calculated_length == expected_length
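
A minimal usage sketch of the new helper, assuming this patch is applied: `_nested_calc_num_data` recurses into any Mapping or Sequence it encounters, collects the integer leaves, and reduces them with `compute_func`. The nested structure below is illustrative and not taken from the patch or its tests.

```python
from pytorch_lightning.trainer.supporters import _nested_calc_num_data

# Hypothetical nested length structure: a Mapping whose values are a
# Sequence of ints and another Mapping of ints.
lengths = {"train": [10, 20], "extra": {"a": 5, "b": 15}}

# Mappings are converted to lists of their values, nested containers are
# resolved recursively, then compute_func is applied to the flattened leaves.
assert _nested_calc_num_data(lengths, min) == 5
assert _nested_calc_num_data(lengths, max) == 20
```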