3 changes: 1 addition & 2 deletions src/lightning/fabric/CHANGELOG.md
@@ -18,8 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))


-

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))

### Deprecated

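The entry above ("Allow using iterable-style datasets with TPUs") corresponds to dropping the `__len__` check from `XLAStrategy.process_dataloader` in the diffs below. A minimal sketch of what this enables on the Fabric side — the `RandomStream` dataset and the `run` function are hypothetical placeholders, not code from this PR:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

from lightning.fabric import Fabric


class RandomStream(IterableDataset):
    """Hypothetical iterable-style dataset with no __len__."""

    def __iter__(self):
        for _ in range(64):
            yield torch.randn(32)


def run(fabric: Fabric) -> None:
    loader = DataLoader(RandomStream(), batch_size=4)
    # Before #17331 this raised TypeError ("TPUs do not currently support
    # IterableDataset objects ..."); now the loader is wrapped in torch_xla's
    # MpDeviceLoader and can be iterated as usual.
    loader = fabric.setup_dataloaders(loader)
    for batch in loader:
        pass  # training step would go here


if __name__ == "__main__":
    fabric = Fabric(accelerator="tpu", devices=8)
    fabric.launch(run)
```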
10 changes: 0 additions & 10 deletions src/lightning/fabric/strategies/xla.py
@@ -30,7 +30,6 @@
from lightning.fabric.strategies import ParallelStrategy
from lightning.fabric.strategies.launchers.xla import _XLALauncher
from lightning.fabric.strategies.strategy import TBroadcast
from lightning.fabric.utilities.data import has_len
from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.types import _PATH, ReduceOp

@@ -105,7 +104,6 @@ def module_to_device(self, module: Module) -> None:
module.to(self.root_device)

def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
XLAStrategy._validate_dataloader(dataloader)
from torch_xla.distributed.parallel_loader import MpDeviceLoader

if isinstance(dataloader, MpDeviceLoader):
@@ -210,11 +208,3 @@ def _set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
rank_zero_only.rank = self.cluster_environment.global_rank()

@staticmethod
def _validate_dataloader(dataloader: object) -> None:
if not has_len(dataloader):
raise TypeError(
"TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
" HINT: You can mock the length on your dataset to bypass this error."
)
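
With `_validate_dataloader` gone, `process_dataloader` reduces to wrapping whatever loader it receives in torch_xla's `MpDeviceLoader`. A simplified sketch of the remaining code path (paraphrased for illustration, not a verbatim copy of the strategy code):

```python
from torch.utils.data import DataLoader
from torch_xla.distributed.parallel_loader import MpDeviceLoader


def process_dataloader(dataloader: DataLoader, root_device) -> MpDeviceLoader:
    # No __len__ / IterableDataset check anymore: every loader, including
    # iterable-style ones, is wrapped for asynchronous host-to-device transfer.
    if isinstance(dataloader, MpDeviceLoader):
        return dataloader
    mp_loader = MpDeviceLoader(dataloader, root_device)
    # Keep Lightning's dataset bookkeeping working on the wrapper.
    mp_loader.dataset = dataloader.dataset
    return mp_loader
```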
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -29,8 +29,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Disable `torch.inference_mode` with `torch.compile` in PyTorch 2.0 ([#17215](https://github.com/Lightning-AI/lightning/pull/17215))


- Changed the `is_picklable` util function to handle the edge case that throws a `TypeError` ([#17270](https://github.com/Lightning-AI/lightning/pull/17270))


- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))

### Deprecated

-
10 changes: 0 additions & 10 deletions src/lightning/pytorch/strategies/xla.py
@@ -23,7 +23,6 @@
from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE
from lightning.fabric.plugins import CheckpointIO, XLACheckpointIO
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.utilities.data import has_len
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.types import _PATH, ReduceOp
from lightning.pytorch.overrides.base import _LightningModuleWrapperBase
@@ -97,14 +96,6 @@ def root_device(self) -> torch.device:
def local_rank(self) -> int:
return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0

@staticmethod
def _validate_dataloader(dataloader: object) -> None:
if not has_len(dataloader):
raise TypeError(
"TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
" HINT: You can mock the length on your dataset to bypass this error."
)

def connect(self, model: "pl.LightningModule") -> None:
import torch_xla.distributed.xla_multiprocessing as xmp

@@ -147,7 +138,6 @@ def is_distributed(self) -> bool:
return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1

def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
XLAStrategy._validate_dataloader(dataloader)
from torch_xla.distributed.parallel_loader import MpDeviceLoader

if isinstance(dataloader, MpDeviceLoader):
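The same relaxation applies to the PyTorch `XLAStrategy`. A hedged usage sketch on the Trainer side — `RandomStream` and `IterableBoringModel` are hypothetical, while `Trainer`, `BoringModel`, and the TPU arguments are the existing API:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class RandomStream(IterableDataset):
    """Hypothetical iterable-style dataset with no __len__."""

    def __iter__(self):
        for _ in range(64):
            yield torch.randn(32)


class IterableBoringModel(BoringModel):
    def train_dataloader(self) -> DataLoader:
        return DataLoader(RandomStream(), batch_size=4)


# Before this change, fitting on TPU with an iterable-style dataset failed in
# XLAStrategy.process_dataloader; now the loader is wrapped in MpDeviceLoader.
trainer = Trainer(accelerator="tpu", devices=8, max_steps=8, limit_val_batches=0)
trainer.fit(IterableBoringModel())
```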
24 changes: 2 additions & 22 deletions tests/tests_fabric/strategies/test_xla.py
@@ -14,7 +14,7 @@
import os
from functools import partial
from unittest import mock
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock

import pytest
import torch
@@ -24,8 +24,7 @@
from lightning.fabric.strategies import XLAStrategy
from lightning.fabric.strategies.launchers.xla import _XLALauncher
from lightning.fabric.utilities.distributed import ReduceOp
from tests_fabric.helpers.dataloaders import CustomNotImplementedErrorDataloader
from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
from tests_fabric.helpers.models import RandomDataset
from tests_fabric.helpers.runif import RunIf


@@ -110,25 +109,6 @@ def __instancecheck__(self, instance):
assert processed_dataloader.batch_sampler == processed_dataloader._loader.batch_sampler


_loader = DataLoader(RandomDataset(32, 64))
_iterable_loader = DataLoader(RandomIterableDataset(32, 64))
_loader_no_len = CustomNotImplementedErrorDataloader(_loader)


@RunIf(tpu=True)
@pytest.mark.parametrize("dataloader", [None, _iterable_loader, _loader_no_len])
@mock.patch("lightning.fabric.strategies.xla.XLAStrategy.root_device")
def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatch):
"""Test that the XLAStrategy validates against dataloaders with no length defined on datasets (iterable
dataset)."""
import torch_xla.distributed.parallel_loader as parallel_loader

monkeypatch.setattr(parallel_loader, "MpDeviceLoader", Mock())

with pytest.raises(TypeError, match="TPUs do not currently support"):
XLAStrategy().process_dataloader(dataloader)


def tpu_all_gather_fn(strategy):
for sync_grads in [True, False]:
tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True)
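A hypothetical replacement test (not part of this PR) that asserts the new behaviour instead of the old error; it reuses the `root_device` patch and `RandomIterableDataset` from the removed test and assumes a real TPU environment via `RunIf(tpu=True)`:

```python
from unittest import mock

from torch.utils.data import DataLoader

from lightning.fabric.strategies import XLAStrategy
from tests_fabric.helpers.models import RandomIterableDataset
from tests_fabric.helpers.runif import RunIf


@RunIf(tpu=True)
@mock.patch("lightning.fabric.strategies.xla.XLAStrategy.root_device")
def test_xla_process_iterable_dataloader(_):
    """Hypothetical check: an iterable-style loader is wrapped instead of rejected."""
    from torch_xla.distributed.parallel_loader import MpDeviceLoader

    dataloader = DataLoader(RandomIterableDataset(32, 64))
    processed = XLAStrategy().process_dataloader(dataloader)
    # The loader without __len__ is accepted and wrapped for device transfer.
    assert isinstance(processed, MpDeviceLoader)
    assert processed._loader is dataloader
```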
13 changes: 1 addition & 12 deletions tests/tests_pytorch/strategies/test_xla.py
@@ -13,26 +13,15 @@
# limitations under the License.
import os
from unittest import mock
from unittest.mock import MagicMock

import pytest
import torch
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import XLAStrategy
from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader
from tests_pytorch.helpers.runif import RunIf


def test_error_process_iterable_dataloader(xla_available):
strategy = XLAStrategy(MagicMock())
loader_no_len = CustomNotImplementedErrorDataloader(DataLoader(RandomDataset(32, 64)))
with pytest.raises(TypeError, match="TPUs do not currently support"):
strategy.process_dataloader(loader_no_len)


class BoringModelTPU(BoringModel):
def on_train_start(self) -> None:
# assert strategy attributes for device setting