From 8862995f1d6807005dc644c000009f7a6807170e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 7 Apr 2021 19:29:42 +0100 Subject: [PATCH 1/9] Add error message with TPUSpawn + IterableDataset --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 68068935127e2..0b6695739120f 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -24,6 +24,7 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -65,6 +66,10 @@ def is_distributed(self): return self.world_size != 1 def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader: + if not has_len(dataloader): + raise MisconfigurationException( + "TPUSpawn does not currently support IterableDatasets, the dataset must implement __len__." + ) device = xm.xla_device() dataloader = MpDeviceLoader(dataloader, device) return dataloader From 159e16f391435a617143ac0a7dde5dae6c587465 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 7 Apr 2021 20:53:38 +0100 Subject: [PATCH 2/9] Update --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 0b6695739120f..f9aa267794181 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -68,7 +68,7 @@ def is_distributed(self): def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader: if not has_len(dataloader): raise MisconfigurationException( - "TPUSpawn does not currently support IterableDatasets, the dataset must implement __len__." + "TPUSpawn does not currently support IterableDataset objects, the dataset must implement __len__." ) device = xm.xla_device() dataloader = MpDeviceLoader(dataloader, device) From 0058b902f95a74e3ba786b73094d7373d039a13c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 7 Apr 2021 21:12:13 +0100 Subject: [PATCH 3/9] Update typing --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index f9aa267794181..86285de56275a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -65,7 +65,7 @@ def distributed_sampler_kwargs(self) -> dict: def is_distributed(self): return self.world_size != 1 - def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader: + def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> MpDeviceLoader: if not has_len(dataloader): raise MisconfigurationException( "TPUSpawn does not currently support IterableDataset objects, the dataset must implement __len__." From d1fe5a3ff14f9f07faa9813cd802f6fa4e00c20d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 8 Apr 2021 09:20:05 +0100 Subject: [PATCH 4/9] Add additional error and tests --- .../plugins/training_type/tpu_spawn.py | 46 ++++++++++-- tests/plugins/test_tpu_spawn.py | 74 +++++++++++++++++++ 2 files changed, 114 insertions(+), 6 deletions(-) create mode 100644 tests/plugins/test_tpu_spawn.py diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 86285de56275a..381ee0e0dd617 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -15,12 +15,13 @@ import os import re import time -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING import torch import torch.multiprocessing as mp from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -41,6 +42,11 @@ from omegaconf import DictConfig, ListConfig, OmegaConf +if TYPE_CHECKING: + from torch.nn import Module + from torch.utils.data import DataLoader + + class TPUSpawnPlugin(DDPSpawnPlugin): def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None: @@ -48,6 +54,37 @@ def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[ self.tpu_local_core_rank = 0 self.start_method = None + @staticmethod + def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']): + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: + if not has_len(dataloader): + raise MisconfigurationException( + "TPUSpawn does not currently support IterableDataset objects, the dataset must implement __len__." + ) + + @staticmethod + def _validate_patched_dataloaders(model: 'Module') -> None: + """Validate and fail fast if the dataloaders were passed directly to fit. + """ + if isinstance(model.train_dataloader, _PatchDataLoader): + TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader) + + if isinstance(model.val_dataloader, _PatchDataLoader): + TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader) + + if isinstance(model.test_dataloader, _PatchDataLoader): + TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader) + + if isinstance(model.predict_dataloader, _PatchDataLoader): + TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader) + + def connect(self, model: 'Module') -> None: + TPUSpawnPlugin._validate_patched_dataloaders(model) + return super().connect(model) + def setup(self, model: torch.nn.Module) -> torch.nn.Module: self.create_mp_queue() return self.model @@ -65,11 +102,8 @@ def distributed_sampler_kwargs(self) -> dict: def is_distributed(self): return self.world_size != 1 - def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> MpDeviceLoader: - if not has_len(dataloader): - raise MisconfigurationException( - "TPUSpawn does not currently support IterableDataset objects, the dataset must implement __len__." - ) + def process_dataloader(self, dataloader: 'DataLoader') -> MpDeviceLoader: + TPUSpawnPlugin._validate_dataloader(dataloader) device = xm.xla_device() dataloader = MpDeviceLoader(dataloader, device) return dataloader diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py new file mode 100644 index 0000000000000..d141f5e43ae64 --- /dev/null +++ b/tests/plugins/test_tpu_spawn.py @@ -0,0 +1,74 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import MagicMock + +import pytest +from torch.utils.data import DataLoader + +from pytorch_lightning.plugins.training_type import TPUSpawnPlugin +from pytorch_lightning.trainer.connectors.data_connector import DataConnector +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel, RandomDataset +from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader + + +class BoringModelNoDataloaders(BoringModel): + def train_dataloader(self): + raise NotImplementedError + + def val_dataloader(self): + raise NotImplementedError + + def test_dataloader(self): + raise NotImplementedError + + def predict_dataloader(self): + raise NotImplementedError + + +_loader = DataLoader(RandomDataset(32, 64)) +_loader_no_len = CustomNotImplementedErrorDataloader(_loader) + + +@pytest.mark.parametrize( + "train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders", + [ + (_loader_no_len, None, None, None), + (None, _loader_no_len, None, None), + (None, None, _loader_no_len, None), + (None, None, None, _loader_no_len), + (None, [_loader, _loader_no_len], None, None), + ], +) +def test_error_patched_iterable_dataloaders( + tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders +): + model = BoringModelNoDataloaders() + connector = DataConnector(MagicMock()) + + connector.attach_dataloaders( + model, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + test_dataloaders=test_dataloaders, + predict_dataloaders=predict_dataloaders, + ) + + with pytest.raises(MisconfigurationException, match="TPUSpawn does not currently support"): + TPUSpawnPlugin(MagicMock()).connect(model) + + +def test_error_process_iterable_dataloader(tmpdir): + with pytest.raises(MisconfigurationException, match="TPUSpawn does not currently support"): + TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len) From 8da584f63d774216a57d55cd32fe102c0b1639d8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 8 Apr 2021 09:30:26 +0100 Subject: [PATCH 5/9] Update error message --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 ++- tests/plugins/test_tpu_spawn.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 381ee0e0dd617..4b7793a3e5705 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -62,7 +62,8 @@ def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']): for dataloader in dataloaders: if not has_len(dataloader): raise MisconfigurationException( - "TPUSpawn does not currently support IterableDataset objects, the dataset must implement __len__." + "TPUs do not currently support IterableDataset objects, the dataset must implement __len__." + "HINT: You can mock the length on your dataset to bypass this MisconfigurationException." ) @staticmethod diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index d141f5e43ae64..0a045e7c48be1 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -65,10 +65,10 @@ def test_error_patched_iterable_dataloaders( predict_dataloaders=predict_dataloaders, ) - with pytest.raises(MisconfigurationException, match="TPUSpawn does not currently support"): + with pytest.raises(MisconfigurationException, match="TPUs do not currently support"): TPUSpawnPlugin(MagicMock()).connect(model) def test_error_process_iterable_dataloader(tmpdir): - with pytest.raises(MisconfigurationException, match="TPUSpawn does not currently support"): + with pytest.raises(MisconfigurationException, match="TPUs do not currently support"): TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len) From a30389b2f85a332f0e028b7b7d06049f8494a731 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 8 Apr 2021 09:34:04 +0100 Subject: [PATCH 6/9] Update error message --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4b7793a3e5705..49711e82f9f5b 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -63,7 +63,7 @@ def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']): if not has_len(dataloader): raise MisconfigurationException( "TPUs do not currently support IterableDataset objects, the dataset must implement __len__." - "HINT: You can mock the length on your dataset to bypass this MisconfigurationException." + " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." ) @staticmethod From 0904608a5f619bd7887fe47736963eb752720483 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 8 Apr 2021 13:02:32 +0100 Subject: [PATCH 7/9] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- tests/plugins/test_tpu_spawn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 49711e82f9f5b..1ba6e3ff99108 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -62,7 +62,7 @@ def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']): for dataloader in dataloaders: if not has_len(dataloader): raise MisconfigurationException( - "TPUs do not currently support IterableDataset objects, the dataset must implement __len__." + "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." ) diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index 0a045e7c48be1..bb587827c3a3f 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -52,7 +52,7 @@ def predict_dataloader(self): ], ) def test_error_patched_iterable_dataloaders( - tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders + tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders ): model = BoringModelNoDataloaders() connector = DataConnector(MagicMock()) From bb013ca429f85545ad2ed4dbfa3b024d453d6fe3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 8 Apr 2021 13:05:09 +0100 Subject: [PATCH 8/9] Add hasattr checks --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 1ba6e3ff99108..a1b0a45c9e549 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -70,16 +70,16 @@ def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']): def _validate_patched_dataloaders(model: 'Module') -> None: """Validate and fail fast if the dataloaders were passed directly to fit. """ - if isinstance(model.train_dataloader, _PatchDataLoader): + if hasattr(model, 'train_dataloader') and isinstance(model.train_dataloader, _PatchDataLoader): TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader) - if isinstance(model.val_dataloader, _PatchDataLoader): + if hasattr(model, 'val_dataloader') and isinstance(model.val_dataloader, _PatchDataLoader): TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader) - if isinstance(model.test_dataloader, _PatchDataLoader): + if hasattr(model, 'test_dataloader') and isinstance(model.test_dataloader, _PatchDataLoader): TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader) - if isinstance(model.predict_dataloader, _PatchDataLoader): + if hasattr(model, 'predict_dataloader') and isinstance(model.predict_dataloader, _PatchDataLoader): TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader) def connect(self, model: 'Module') -> None: From 0e368e7bb962f58006825d37e27f22e622b58fd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 8 Apr 2021 14:19:15 +0200 Subject: [PATCH 9/9] Update pytorch_lightning/plugins/training_type/tpu_spawn.py --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index a1b0a45c9e549..d546067e88a1c 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -86,7 +86,7 @@ def connect(self, model: 'Module') -> None: TPUSpawnPlugin._validate_patched_dataloaders(model) return super().connect(model) - def setup(self, model: torch.nn.Module) -> torch.nn.Module: + def setup(self, model: 'Module') -> 'Module': self.create_mp_queue() return self.model