diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bd9dd13a6a01..8e126d948090c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,20 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [1.3.2] - 2021-05-18 + +### Changed + +- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238)) + +### Fixed + +- Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) +- Fixed recursive passing of `wrong_dtype` keyword argument in `pytorch_lightning.utilities.apply_to_collection` ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) +- Fixed setting correct `DistributedType` for `ddp_cpu` (spawn) backend ([#7492](https://github.com/PyTorchLightning/pytorch-lightning/pull/7492)) +- Fixed incorrect number of calls to LR scheduler when `check_val_every_n_epoch > 1` ([#7032](https://github.com/PyTorchLightning/pytorch-lightning/pull/7032)) + + ## [1.3.1] - 2021-05-11 ### Fixed diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index a602a75b0f877..fbb19e10a8e1e 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=32) - -.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``. - - --------------- LightningDataModule API @@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup def setup(self, stage: Optional[str] = None): # Assign Train/val split(s) for use in Dataloaders - if stage == 'fit' or stage is None: + if stage in (None, 'fit'): mnist_full = MNIST( self.data_dir, train=True, @@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup self.dims = self.mnist_train[0][0].shape # Assign Test split(s) for use in Dataloaders - if stage == 'test' or stage is None: + if stage in (None, 'test'): self.mnist_test = MNIST( self.data_dir, train=False, @@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) -.. warning:: ``setup`` is called from every process. Setting state here is okay. - +:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects a ``stage: Optional[str]`` argument. +It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``, +we assume all stages have been set up. +.. note:: ``setup`` is called from every process. Setting state here is okay. .. note:: ``teardown`` can be used to clean up the state. It is also called from every process +.. note:: + ``{setup,teardown,prepare_data}`` will only be called once for a given stage. + If the stage was ``None``, we assume ``{fit,validate,test}`` have all been called. For example, this means that + any duplicate ``dm.setup('fit')`` calls will be a no-op.
To avoid this, you can overwrite + ``dm._has_setup_fit = False`` train_dataloader @@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply: dm = MNISTDataModule() model = Model() trainer.fit(model, dm) - trainer.test(datamodule=dm) -If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning -still ensures the method runs on the correct devices) +If you need information from the dataset to build your model, then run +:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and +:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures +the method runs on the correct devices). .. code-block:: python @@ -416,7 +420,7 @@ still ensures the method runs on the correct devices) ---------------- -Datamodules without Lightning +DataModules without Lightning ----------------------------- You can of course use DataModules in plain PyTorch code as well. diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index 8d35e27185649..680a388ee118b 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo 1. use ``prepare_data()`` to download and process the dataset. 2. use ``setup()`` to do splits, and build your model internals -| - An alternative to using a DataModule is to defer initialization of the models modules to the ``setup`` method of your LightningModule as follows: .. testcode:: diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index 22b2b535f09a2..8763799131d58 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning. transforms.Normalize((0.1307,), (0.3081,)) ]) # split dataset - if stage == 'fit': + if stage in (None, 'fit'): mnist_train = MNIST(os.getcwd(), train=True, transform=transform) self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000]) - if stage == 'test': + if stage in (None, 'test'): self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform) # return the dataloader for each split diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py index 67ca6d9e8d167..8073e34802df2 100644 --- a/pytorch_lightning/__about__.py +++ b/pytorch_lightning/__about__.py @@ -1,7 +1,7 @@ import time _this_year = time.strftime("%Y") -__version__ = '1.3.1' +__version__ = '1.3.2' __author__ = 'William Falcon et al.'
__author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 4eaed6c7b2bd0..23626ed9cbeae 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -355,6 +355,7 @@ def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable @functools.wraps(fn) def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: name = fn.__name__ + has_run = False # If calling setup, we check the stage and assign stage-specific bool args if name in ("setup", "teardown"): @@ -366,15 +367,22 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: stage = args[0] if len(args) else kwargs.get("stage", None) if stage is None: + has_run = True for s in ("fit", "validate", "test"): - setattr(obj, f"_has_{name}_{s}", True) + attr = f"_has_{name}_{s}" + has_run &= getattr(obj, attr) + setattr(obj, attr, True) else: - setattr(obj, f"_has_{name}_{stage}", True) + attr = f"_has_{name}_{stage}" + has_run = getattr(obj, attr) + setattr(obj, attr, True) elif name == "prepare_data": + has_run = obj._has_prepared_data obj._has_prepared_data = True - return fn(*args, **kwargs) + if not has_run: + return fn(*args, **kwargs) return wrapped_fn diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 7ab0c8acbe329..7cc74f3d0452e 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -394,7 +394,7 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None) -> None: """ - Called at the beginning of fit (train + validate), validate, test, predict, or tune. + Called at the beginning of fit (train + validate), validate, test, and predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index a8a72c1831600..d826de1047851 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -522,7 +522,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): # special case with DDP on CPUs if self.distributed_backend == "ddp_cpu": - self._distrib_type = DistributedType.DDP + self._distrib_type = DistributedType.DDP_SPAWN if self.num_gpus > 0: rank_zero_warn( 'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.' diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 18a012da54760..df6db1e180c24 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -262,8 +262,7 @@ def max_len(self) -> Union[int, float]: def min_len(self) -> Union[int, float]: return self._calc_num_data(self.datasets, 'min_size') - @staticmethod - def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: + def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: """ Compute the length of `CombinedDataset` according to the `mode`. 
@@ -281,9 +280,7 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, raise MisconfigurationException(f"Invalid Mode: {mode}") # extract the lengths - all_lengths = apply_to_collection( - datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping) - ) + all_lengths = self._get_len_recursive(datasets) compute_func = CombinedDataset.COMPUTE_FUNCS[mode] @@ -294,6 +291,30 @@ def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, return length + def _get_len_recursive(self, data) -> int: + if isinstance(data, Dataset): + return len(data) + + elif isinstance(data, (float, int)): + return data + + elif isinstance(data, Mapping): + if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()): + return {k: self._get_len_recursive(v) for k, v in data.items()} + elif isinstance(data, Sequence): + data = list(data) + if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data): + return [self._get_len_recursive(v) for v in data] + + return self._get_len(data) + + @staticmethod + def _get_len(dataset) -> int: + try: + return len(dataset) + except (TypeError, NotImplementedError): + return float('inf') + def __len__(self) -> int: """Return the minimum length of the datasets.""" return self._calc_num_data(self.datasets, self.mode) @@ -335,6 +356,9 @@ def __init__(self, loaders: Any, mode: str = 'min_size'): 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones. """ + if mode not in self.SUPPORTED_MODES: + raise MisconfigurationException(f"Invalid Mode: {mode}") + self.loaders = loaders datasets = apply_to_collection( @@ -343,9 +367,6 @@ def __init__(self, loaders: Any, mode: str = 'min_size'): # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader self.dataset = CombinedDataset(datasets, mode) - if mode not in self.SUPPORTED_MODES: - raise MisconfigurationException(f"Invalid Mode: {mode}") - self.mode = mode if self.mode == 'max_size_cycle': @@ -366,27 +387,13 @@ def _wrap_loaders_max_size_cycle(self) -> Any: """ all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) - if isinstance(all_lengths, (int, float)): - length = all_lengths - - elif isinstance(all_lengths, Mapping): - length = max(all_lengths.values()) + length = _nested_calc_num_data(all_lengths, max) - elif isinstance(all_lengths, Sequence): - length = max(all_lengths) - - if isinstance(self.loaders, Mapping): - self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()}) - - elif isinstance(self.loaders, Sequence): - self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders]) - - # dataloaders are iterable but not sequence - elif isinstance(self.loaders, Iterable): - # only one dataloader, just keep it the same. 
- pass - else: - raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}') + # multiple loaders + if isinstance(self.loaders, (Sequence, Mapping)): + self.loaders = apply_to_collection( + self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping) + ) def __iter__(self) -> Any: """ @@ -490,7 +497,7 @@ def create_loader_iters( def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable): - if isinstance(data, int): + if isinstance(data, (float, int)): return data if isinstance(data, Mapping): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2a6a53a7c192c..4e137ad9b7258 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1156,10 +1156,7 @@ def call_setup_hook(self, model: LightningModule) -> None: self.accelerator.barrier("pre_setup") if self.datamodule is not None: - called = getattr(self.datamodule, f'has_setup_{fn}') - if not called: - self.datamodule.setup(stage=fn) - + self.datamodule.setup(stage=fn) self.setup(model, stage=fn) model.setup(stage=fn) @@ -1182,10 +1179,7 @@ def call_teardown_hook(self, model: LightningModule) -> None: fn = self.state.fn._setup_fn if self.datamodule is not None: - called = getattr(self.datamodule, f'has_teardown_{fn}') - if not called: - self.datamodule.teardown(stage=fn) - + self.datamodule.teardown(stage=fn) self.profiler.teardown(stage=fn) self.teardown(stage=fn) model.teardown(stage=fn) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 790dc4c70bdeb..b3621ee176677 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -474,7 +474,6 @@ def run_training_epoch(self): train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 - val_loop_called = False batch_idx = None is_last_batch = None @@ -516,7 +515,6 @@ def run_training_epoch(self): self.trainer.validating = True self.trainer.run_evaluation() self.trainer.training = True - val_loop_called = True # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
@@ -565,7 +563,7 @@ def run_training_epoch(self): should_train_only = self.trainer.disable_validation or should_skip_eval # update epoch level lr_schedulers if no val loop outside train loop is triggered - if (val_loop_called and not should_check_val) or should_train_only: + if not should_check_val or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') if should_train_only: diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index e100a803bcd00..1cbab2fb8dee9 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -85,13 +85,20 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) + return elem_type({ + k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for k, v in data.items() + }) if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple - return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) + return elem_type( + *(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data) + ) if isinstance(data, Sequence) and not isinstance(data, str): - return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) + return elem_type([ + apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data + ]) # data is neither of dtype, nor a collection return data diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index a57fbb4afcbdc..50ea624dbec9d 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -437,13 +437,15 @@ def test_ipython_incompatible_backend_error(*_): with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"): Trainer(accelerator="ddp", gpus=2) - with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"): - Trainer(accelerator="ddp_cpu", num_processes=2) - with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"): Trainer(accelerator="ddp2", gpus=2) +@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True) +def test_ipython_compatible_backend(*_): + Trainer(accelerator="ddp_cpu", num_processes=2) + + @pytest.mark.parametrize( ["accelerator", "plugin"], [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')], diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py index 50c85ddc3f79d..0ed38a4d8d011 100644 --- a/tests/base/model_train_dataloaders.py +++ b/tests/base/model_train_dataloaders.py @@ -39,7 +39,14 @@ def train_dataloader__zero_length(self): def train_dataloader__multiple_mapping(self): """Return a mapping loaders with different lengths""" - return { - 'a': self.dataloader(train=True, num_samples=100), - 'b': self.dataloader(train=True, num_samples=50), - } + + # List[DataLoader] + loaders_a_b = [self.dataloader(num_samples=100, train=True), self.dataloader(num_samples=50, train=True)] + loaders_c_d_e = [ + self.dataloader(num_samples=50, train=True), + self.dataloader(num_samples=50, train=True), + self.dataloader(num_samples=50, train=True) + ] + # Dict[str, List[DataLoader]] + loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e} + return loaders diff 
--git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 2a4161a23e053..12c60ae1ae138 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -63,10 +63,13 @@ def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=No assert isinstance(batch, dict) assert len(batch) == 2 - assert 'a' in batch and 'b' in batch + + assert 'a_b' in batch and 'c_d_e' in batch, batch.keys() + assert isinstance(batch['a_b'], list) and len(batch['a_b']) == 2 + assert isinstance(batch['c_d_e'], list) and len(batch['c_d_e']) == 3 # forward pass - x, y = batch['a'] + x, y = batch['a_b'][0] x = x.view(x.size(0), -1) y_hat = self(x) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 0041ccb52c2bb..7cfa569115550 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -521,3 +521,46 @@ def test_dm_init_from_datasets_dataloaders(iterable): call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True) ]) + + +def test_datamodule_hooks_calls(tmpdir): + """Test that repeated calls to DataHooks' hooks have no effect""" + + class TestDataModule(BoringDataModule): + setup_calls = [] + teardown_calls = [] + prepare_data_calls = 0 + + def setup(self, stage=None): + super().setup(stage=stage) + self.setup_calls.append(stage) + + def teardown(self, stage=None): + super().teardown(stage=stage) + self.teardown_calls.append(stage) + + def prepare_data(self): + super().prepare_data() + self.prepare_data_calls += 1 + + dm = TestDataModule() + dm.prepare_data() + dm.prepare_data() + dm.setup('fit') + dm.setup('fit') + dm.setup() + dm.setup() + dm.teardown('validate') + dm.teardown('validate') + + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate'] + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + trainer.test(BoringModel(), datamodule=dm) + + # same number of calls + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate', 'test'] diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index f5b2229f8a99e..a81e0eecf5c61 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + import pytest import torch from torch import optim @@ -577,21 +579,21 @@ def configure_optimizers(self): trainer.fit(model) -class TestModel(BoringModel): +@RunIf(min_gpus=2, special=True) +def test_optimizer_state_on_device(tmpdir): + """ Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """ - def configure_optimizers(self): - # Adagrad creates state tensors immediately, model is not yet on GPU. - return optim.Adagrad(self.parameters()) + class TestModel(BoringModel): - def on_train_start(self, *args, **kwargs): - opt = self.optimizers() - _, state = next(iter(opt.state.items())) - assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device + def configure_optimizers(self): + # Adagrad creates state tensors immediately, model is not yet on GPU. 
+ return optim.Adagrad(self.parameters()) + def on_train_start(self, *args, **kwargs): + opt = self.optimizers() + _, state = next(iter(opt.state.items())) + assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device -@RunIf(min_gpus=2, special=True) -def test_optimizer_state_on_device(tmpdir): - """ Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """ model = TestModel() trainer = Trainer( default_root_dir=tmpdir, @@ -600,3 +602,21 @@ def test_optimizer_state_on_device(tmpdir): fast_dev_run=True, ) trainer.fit(model) + + +@pytest.mark.parametrize("check_val_every_n_epoch", [1, 2]) +@mock.patch("torch.optim.lr_scheduler.StepLR.step") +def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch, tmpdir): + epochs = 4 + expected_steps = epochs + 1 # every LRScheduler gets called once at init + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + check_val_every_n_epoch=check_val_every_n_epoch, + max_epochs=epochs, + ) + trainer.fit(model) + assert mocked_sched.call_count == expected_steps diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 6f78f125754b5..6caf7ee132300 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -766,7 +766,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): train_dl = model.dataloader(train=False) train_dl.num_workers = 0 - train_multi_dl = {'a': train_dl, 'b': train_dl} + train_multi_dl = {'a_b': [train_dl, train_dl], 'c_d_e': [train_dl, train_dl, train_dl]} val_multi_dl = [val_dl, val_dl] test_multi_dl = [train_dl, train_dl] diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 6d6b1e9ad1bdf..169c8cb80b04d 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -123,8 +123,9 @@ def test_combined_dataset(dataset_1, dataset_2): def test_combined_dataset_length_mode_error(): + dset = CombinedDataset([range(10)]) with pytest.raises(MisconfigurationException, match="Invalid Mode"): - CombinedDataset._calc_num_data([range(10)], "test") + dset._calc_num_data([range(10)], "test") def test_combined_loader_iterator_dict_min_size(): @@ -146,13 +147,13 @@ def test_combined_loader_iterator_dict_min_size(): def test_combined_loader_init_mode_error(): """Test the ValueError when constructing `CombinedLoader`""" - with pytest.raises(MisconfigurationException, match="selected unsupported mode"): + with pytest.raises(MisconfigurationException, match="Invalid Mode"): CombinedLoader([range(10)], "testtt") def test_combined_loader_loader_type_error(): """Test the ValueError when wrapping the loaders""" - with pytest.raises(ValueError, match="Invalid Datatype"): + with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"): CombinedLoader(None, "max_size_cycle") diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b717302adf31f..96a828b44ddf0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1140,18 +1140,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): num_processes=1, ), ), - ( - dict(accelerator="dp", gpus=None), - dict( - use_dp=False, - use_ddp=False, - use_ddp2=False, - num_gpus=0, - on_gpu=False, - use_single_gpu=False, - num_processes=1, - ), - ), ( dict(accelerator="ddp", gpus=None), dict(
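For context on the DataModule change above: once a stage has run, repeated hook calls on the same ``LightningDataModule`` instance become no-ops, which is what the new ``test_datamodule_hooks_calls`` test exercises. A minimal sketch of that behaviour from user code (the ``PrintingDataModule`` class is illustrative only, not part of the patch):

```python
from pytorch_lightning import LightningDataModule


class PrintingDataModule(LightningDataModule):
    """Illustrative-only DataModule that reports when ``setup`` actually runs."""

    def setup(self, stage=None):
        print(f"setup({stage!r}) ran")


dm = PrintingDataModule()
dm.setup('fit')            # runs: prints "setup('fit') ran"
dm.setup('fit')            # no-op: the 'fit' stage was already set up
dm._has_setup_fit = False  # reset the guard, as the updated docs suggest
dm.setup('fit')            # runs again
```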
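The ``apply_to_collection`` fix is easiest to see with the nested ``Dict[str, List[DataLoader]]`` layout used by the updated dataloader tests. A rough sketch, assuming only that ``wrong_dtype`` is now forwarded to the recursive calls (the ``make_loader`` helper is made up for illustration):

```python
from collections.abc import Iterable, Mapping, Sequence

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection


def make_loader(n: int) -> DataLoader:
    # hypothetical helper: a loader with exactly ``n`` batches of size 1
    return DataLoader(TensorDataset(torch.randn(n, 2)), batch_size=1)


# Dict[str, List[DataLoader]] -- the nested layout exercised by the updated tests
loaders = {
    "a_b": [make_loader(100), make_loader(50)],
    "c_d_e": [make_loader(50), make_loader(50), make_loader(50)],
}

# Previously ``wrong_dtype`` was dropped on recursion, so the inner lists matched
# ``Iterable`` and ``len`` was applied to the lists themselves (giving 2 and 3).
# With the fix, the exclusion propagates and ``len`` reaches each DataLoader.
lengths = apply_to_collection(loaders, Iterable, len, wrong_dtype=(Sequence, Mapping))
assert lengths == {"a_b": [100, 50], "c_d_e": [50, 50, 50]}
```

The updated ``train_dataloader__multiple_mapping`` fixture returns exactly this kind of structure.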