From 63356c983c44fd36c35a163842ab86bb98cae4b4 Mon Sep 17 00:00:00 2001
From: gkroiz
Date: Tue, 11 Apr 2023 16:30:54 +0000
Subject: [PATCH 1/4] removing error check on TPUs for IterableDatasets

---
 src/lightning/fabric/strategies/xla.py  | 9 ---------
 src/lightning/pytorch/strategies/xla.py | 9 ---------
 2 files changed, 18 deletions(-)

diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index 6f40ef4ada50c..ca67923b91680 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -105,7 +105,6 @@ def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
     def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
-        XLAStrategy._validate_dataloader(dataloader)
         from torch_xla.distributed.parallel_loader import MpDeviceLoader
 
         if isinstance(dataloader, MpDeviceLoader):
@@ -210,11 +209,3 @@ def _set_world_ranks(self) -> None:
         if self.cluster_environment is None:
             return
         rank_zero_only.rank = self.cluster_environment.global_rank()
-
-    @staticmethod
-    def _validate_dataloader(dataloader: object) -> None:
-        if not has_len(dataloader):
-            raise TypeError(
-                "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
-                " HINT: You can mock the length on your dataset to bypass this error."
-            )
diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index 1d73301ac359f..490d515e0fcf6 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -97,14 +97,6 @@ def root_device(self) -> torch.device:
     def local_rank(self) -> int:
         return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0
 
-    @staticmethod
-    def _validate_dataloader(dataloader: object) -> None:
-        if not has_len(dataloader):
-            raise TypeError(
-                "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
-                " HINT: You can mock the length on your dataset to bypass this error."
-            )
-
     def connect(self, model: "pl.LightningModule") -> None:
         import torch_xla.distributed.xla_multiprocessing as xmp
 
@@ -147,7 +139,6 @@ def is_distributed(self) -> bool:
         return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1
 
     def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
-        XLAStrategy._validate_dataloader(dataloader)
        from torch_xla.distributed.parallel_loader import MpDeviceLoader
 
         if isinstance(dataloader, MpDeviceLoader):
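The behavioral change in [PATCH 1/4] is small: `process_dataloader` no longer rejects dataloaders whose dataset lacks `__len__`; it just wraps them in `MpDeviceLoader` as before. A minimal sketch of what this enables, assuming a machine with `torch_xla` installed and a TPU attached (the `StreamDataset` class below is illustrative, not part of the patch):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamDataset(IterableDataset):
    """Iterable-style dataset: defines __iter__ but deliberately no __len__."""

    def __iter__(self):
        for _ in range(64):
            yield torch.randn(32)


loader = DataLoader(StreamDataset(), batch_size=4)

# Prior to this patch, XLAStrategy.process_dataloader(loader) raised
# TypeError("TPUs do not currently support IterableDataset objects ...").
# After it, the loader is wrapped like any other:
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

device_loader = MpDeviceLoader(loader, xm.xla_device())
for batch in device_loader:  # batches arrive already moved to the XLA device
    assert batch.device.type == "xla"
```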
From 1fd2aa098d05746e74acfa4cbadcfb9c970a4a95 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 11 Apr 2023 17:15:26 +0000
Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/fabric/strategies/xla.py  | 1 -
 src/lightning/pytorch/strategies/xla.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index ca67923b91680..dd01da578f1e5 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -30,7 +30,6 @@
 from lightning.fabric.strategies import ParallelStrategy
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.strategies.strategy import TBroadcast
-from lightning.fabric.utilities.data import has_len
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from lightning.fabric.utilities.types import _PATH, ReduceOp
 
diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index 490d515e0fcf6..7c9d83d749f51 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -23,7 +23,6 @@
 from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE
 from lightning.fabric.plugins import CheckpointIO, XLACheckpointIO
 from lightning.fabric.plugins.environments import XLAEnvironment
-from lightning.fabric.utilities.data import has_len
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.types import _PATH, ReduceOp
 from lightning.pytorch.overrides.base import _LightningModuleWrapperBase
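[PATCH 2/4] only drops the now-unused `has_len` imports. For context, that helper is essentially a guarded `len()` call, which is why iterable-style datasets failed the removed check: `DataLoader.__len__` raises `TypeError` when the underlying dataset defines no `__len__`. A rough sketch of the idea (not the exact Lightning implementation):

```python
def has_len(dataloader: object) -> bool:
    """Return True if len(dataloader) is defined and usable (rough sketch)."""
    try:
        # A DataLoader over an IterableDataset raises TypeError here.
        return len(dataloader) >= 0
    except (TypeError, NotImplementedError):
        return False
```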
From d2ff4db72eb6949822972d6a7153acf77921e6d8 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Tue, 11 Apr 2023 19:16:50 +0200
Subject: [PATCH 3/4] Update tests

---
 tests/tests_fabric/strategies/test_xla.py  | 24 ++--------------------
 tests/tests_pytorch/strategies/test_xla.py | 13 +-----------
 2 files changed, 3 insertions(+), 34 deletions(-)

diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py
index 5282eeaf0fe7d..a9eef4cb10b2f 100644
--- a/tests/tests_fabric/strategies/test_xla.py
+++ b/tests/tests_fabric/strategies/test_xla.py
@@ -14,7 +14,7 @@
 import os
 from functools import partial
 from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import MagicMock
 
 import pytest
 import torch
@@ -24,8 +24,7 @@
 from lightning.fabric.strategies import XLAStrategy
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.utilities.distributed import ReduceOp
-from tests_fabric.helpers.dataloaders import CustomNotImplementedErrorDataloader
-from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
+from tests_fabric.helpers.models import RandomDataset
 from tests_fabric.helpers.runif import RunIf
 
 
@@ -110,25 +109,6 @@ def __instancecheck__(self, instance):
     assert processed_dataloader.batch_sampler == processed_dataloader._loader.batch_sampler
 
 
-_loader = DataLoader(RandomDataset(32, 64))
-_iterable_loader = DataLoader(RandomIterableDataset(32, 64))
-_loader_no_len = CustomNotImplementedErrorDataloader(_loader)
-
-
-@RunIf(tpu=True)
-@pytest.mark.parametrize("dataloader", [None, _iterable_loader, _loader_no_len])
-@mock.patch("lightning.fabric.strategies.xla.XLAStrategy.root_device")
-def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatch):
-    """Test that the XLAStrategy validates against dataloaders with no length defined on datasets (iterable
-    dataset)."""
-    import torch_xla.distributed.parallel_loader as parallel_loader
-
-    monkeypatch.setattr(parallel_loader, "MpDeviceLoader", Mock())
-
-    with pytest.raises(TypeError, match="TPUs do not currently support"):
-        XLAStrategy().process_dataloader(dataloader)
-
-
 def tpu_all_gather_fn(strategy):
     for sync_grads in [True, False]:
         tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True)
diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py
index 83add3849c45d..81ebc307e3da8 100644
--- a/tests/tests_pytorch/strategies/test_xla.py
+++ b/tests/tests_pytorch/strategies/test_xla.py
@@ -13,26 +13,15 @@
 # limitations under the License.
 import os
 from unittest import mock
-from unittest.mock import MagicMock
 
-import pytest
 import torch
-from torch.utils.data import DataLoader
 
 from lightning.pytorch import Trainer
-from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
+from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.strategies import XLAStrategy
-from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader
 from tests_pytorch.helpers.runif import RunIf
 
 
-def test_error_process_iterable_dataloader(xla_available):
-    strategy = XLAStrategy(MagicMock())
-    loader_no_len = CustomNotImplementedErrorDataloader(DataLoader(RandomDataset(32, 64)))
-    with pytest.raises(TypeError, match="TPUs do not currently support"):
-        strategy.process_dataloader(loader_no_len)
-
-
 class BoringModelTPU(BoringModel):
     def on_train_start(self) -> None:
         # assert strategy attributes for device setting
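[PATCH 3/4] deletes the two tests that asserted the old `TypeError`. A hypothetical inverse test (not part of this series) could pin down the new behavior by checking that an iterable-style dataloader now passes through `process_dataloader`; sketched here against the fabric strategy, assuming `torch_xla` is installed:

```python
from unittest import mock
from unittest.mock import MagicMock

from torch.utils.data import DataLoader

from lightning.fabric.strategies import XLAStrategy
from tests_fabric.helpers.models import RandomIterableDataset


@mock.patch("lightning.fabric.strategies.xla.XLAStrategy.root_device")
def test_xla_accepts_iterable_dataloaders(_, monkeypatch):
    """Hypothetical: iterable dataloaders are wrapped instead of rejected."""
    import torch_xla.distributed.parallel_loader as parallel_loader

    # Substitute the MagicMock class so no real TPU wrapping happens.
    monkeypatch.setattr(parallel_loader, "MpDeviceLoader", MagicMock)
    loader = DataLoader(RandomIterableDataset(32, 64))
    # Before this series, this call raised
    # TypeError("TPUs do not currently support IterableDataset objects ...").
    processed = XLAStrategy().process_dataloader(loader)
    assert isinstance(processed, MagicMock)
```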
From d852d22ca21bfe5a6817ae4642637fc2a7919fe3 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Tue, 11 Apr 2023 19:18:43 +0200
Subject: [PATCH 4/4] CHANGELOG

---
 src/lightning/fabric/CHANGELOG.md  | 3 +--
 src/lightning/pytorch/CHANGELOG.md | 4 ++++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index e4de85c799f41..9c71a6572f2f3 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -17,8 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))
 
--
-
+- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
 
 ### Deprecated
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index a0926a0b565e8..203bbf0517f2d 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -27,8 +27,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Disable `torch.inference_mode` with `torch.compile` in PyTorch 2.0 ([#17215](https://github.com/Lightning-AI/lightning/pull/17215))
 
+
 - Changed the `is_picklable` util function to handle the edge case that throws a `TypeError` ([#17270](https://github.com/Lightning-AI/lightning/pull/17270))
 
+
+- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
+
 ### Deprecated
 
 -
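Taken together with the CHANGELOG entries in [PATCH 4/4], the series enables runs like the following, which previously failed inside `process_dataloader`. A sketch assuming a TPU runtime with `torch_xla` available (`StreamDataset` is again an illustrative stand-in; `max_steps` is set because an endless iterable has no epoch length):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class StreamDataset(IterableDataset):
    """Illustrative endless stream of samples with no __len__."""

    def __iter__(self):
        while True:
            yield torch.randn(32)


# With the check removed, fitting on an iterable-style dataset no longer raises
# "TPUs do not currently support IterableDataset objects ...".
trainer = Trainer(accelerator="tpu", devices=8, max_steps=100)
trainer.fit(BoringModel(), DataLoader(StreamDataset(), batch_size=4))
```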