9 changes: 5 additions & 4 deletions docs/source-pytorch/accelerators/ipu_basic.rst
@@ -62,7 +62,8 @@ Currently there are some known limitations that are being addressed in the near

Please see the `MNIST example <https://github.com/Lightning-AI/lightning/blob/master/examples/pl_ipu/mnist_sample.py>`__, which demonstrates most of these limitations and how to work around them until they are resolved.

* ``self.log`` is not supported in the ``training_step``, ``validation_step``, ``test_step`` or ``predict_step``. This is due to the step function being traced and sent to the IPU devices. We're actively working on fixing this
* Multiple optimizers are not supported. ``training_step`` only supports returning one loss from the ``training_step`` function as a result
* Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code
* Clipping gradients is not supported
* ``self.log`` is not supported in the ``training_step``, ``validation_step``, ``test_step`` or ``predict_step``. This is due to the step function being traced and sent to the IPU devices. We're actively working on fixing this.
* Multiple optimizers are not supported. ``training_step`` only supports returning one loss from the ``training_step`` function as a result.
* Since the step functions are traced, branching logic or any form of primitive values are traced into constants. Be mindful as this could lead to errors in your custom code.
* Clipping gradients is not supported.
* It is not possible to use :class:`torch.utils.data.BatchSampler` in your dataloaders if you are using multiple IPUs.
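
To make the new limitation above concrete, here is a minimal, hypothetical usage sketch. The `Trainer` arguments (`accelerator="ipu"`, `devices=4`) and the dataset sizes are assumptions for illustration, not taken from this PR: on multiple IPUs, pass `batch_size` to the `DataLoader` and let it build its default batch sampler; supplying your own `torch.utils.data.BatchSampler` is rejected by this change.

```python
# Hypothetical sketch of the user-facing behavior described above; the Trainer
# arguments are assumed ways to request a multi-IPU run, not taken from this diff.
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset

dataset = RandomDataset(32, 64)

# Supported: give the DataLoader a batch_size and let it create its default BatchSampler.
train_dl = DataLoader(dataset, batch_size=8)

# Not supported on multiple IPUs: an explicitly constructed BatchSampler, e.g.
# DataLoader(dataset, batch_sampler=BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False))
# would raise a MisconfigurationException when the IPU strategy rebuilds the loader.

trainer = Trainer(accelerator="ipu", devices=4, max_epochs=1)
trainer.fit(BoringModel(), train_dataloaders=train_dl)
```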
2 changes: 1 addition & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -167,7 +167,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated Habana Accelerator's `auto_device_count`, `is_available` & `get_device_name` methods based on the latest torch habana package ([#13423](https://github.com/PyTorchLightning/pytorch-lightning/pull/13423))


-
- Disallowed using `BatchSampler` when running on multiple IPUs ([#13854](https://github.com/PyTorchLightning/pytorch-lightning/pull/13854))


### Deprecated
6 changes: 5 additions & 1 deletion src/pytorch_lightning/strategies/ipu.py
@@ -162,6 +162,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
if self.lightning_module.trainer.enable_validation:
model = poptorch.inferenceModel(model=model, options=inference_opts)
self.poptorch_models[RunningStage.VALIDATING] = model
if self.lightning_module.trainer.num_sanity_val_steps > 0:
self.poptorch_models[RunningStage.SANITY_CHECKING] = model
elif trainer_fn == TrainerFn.VALIDATING:
model = poptorch.inferenceModel(model=model, options=self.inference_opts)
self.poptorch_models[RunningStage.VALIDATING] = model
@@ -228,7 +230,9 @@ def _convert_to_poptorch_loader(
# the user is returning the `poptorch.DataLoader` directly, don't change anything.
return dataloader

dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
dataloader, sampler, mode, self.replication_factor > 1
)
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
return dataloader
68 changes: 43 additions & 25 deletions src/pytorch_lightning/utilities/data.py
@@ -186,7 +186,7 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
def _update_dataloader(
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> DataLoader:
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode)
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
dl_cls = type(dataloader)
try:
dataloader = dl_cls(*dl_args, **dl_kwargs)
@@ -212,7 +212,10 @@ def _update_dataloader(


def _get_dataloader_init_args_and_kwargs(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
dataloader: DataLoader,
sampler: Optional[Sampler],
mode: Optional[RunningStage] = None,
disallow_batch_sampler: bool = False,
) -> Tuple[Tuple[Any], Dict[str, Any]]:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
@@ -264,7 +267,7 @@ def _get_dataloader_init_args_and_kwargs(
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None
else:
dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode))
dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode, disallow_batch_sampler))

required_args = {
p.name
@@ -309,7 +312,10 @@


def _dataloader_init_kwargs_resolve_sampler(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
dataloader: DataLoader,
sampler: Optional[Sampler],
mode: Optional[RunningStage] = None,
disallow_batch_sampler: bool = False,
) -> Dict[str, Any]:
"""This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
re-instantiation.
@@ -321,27 +327,39 @@
fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
batch_sampler = getattr(dataloader, "batch_sampler")
is_predicting = mode == RunningStage.PREDICTING
# checking the batch sampler type is different than PyTorch default.
if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

if fault_tolerant_mode.is_automatic:
fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
fast_forward_sampler.setup(dataloader_batch_size=1)

return {
"sampler": None,
"shuffle": False,
"batch_sampler": batch_sampler,
"batch_size": 1,
"drop_last": False,
}

if batch_sampler is not None:
if disallow_batch_sampler:
# Check that we don't have a PyTorch default batch sampler that was instantiated in DataLoader __init__
if not (
type(batch_sampler) is BatchSampler
and batch_sampler.sampler == sampler
and dataloader.batch_size == batch_sampler.batch_size
):
raise MisconfigurationException(
"It is not possible to have a batch sampler in your dataloader, "
"when running on multiple IPU devices."
)
elif type(batch_sampler) is not BatchSampler or is_predicting:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

if fault_tolerant_mode.is_automatic:
fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
fast_forward_sampler.setup(dataloader_batch_size=1)

return {
"sampler": None,
"shuffle": False,
"batch_sampler": batch_sampler,
"batch_size": 1,
"drop_last": False,
}

if fault_tolerant_mode.is_automatic:
fast_forward_sampler = sampler = FastForwardSampler(sampler)
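
The check added in this hunk hinges on telling apart the `BatchSampler` that `DataLoader.__init__` creates implicitly from one the user passed in. The following standalone sketch is not the Lightning implementation, just an approximate restatement of that heuristic in plain PyTorch (the helper name and the identity comparison against `dataloader.sampler` are this editor's assumptions), showing which loaders would pass and which would be rejected:

```python
# Standalone restatement of the "is this the default batch sampler?" heuristic used above.
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset


def has_custom_batch_sampler(dataloader: DataLoader) -> bool:
    batch_sampler = dataloader.batch_sampler
    if batch_sampler is None:  # e.g. iterable datasets or batch_size=None
        return False
    # A default batch sampler is a plain BatchSampler wrapping the loader's own
    # sampler with the loader's own batch size; anything else counts as custom.
    return not (
        type(batch_sampler) is BatchSampler
        and batch_sampler.sampler is dataloader.sampler
        and batch_sampler.batch_size == dataloader.batch_size
    )


dataset = TensorDataset(torch.randn(100, 5))
print(has_custom_batch_sampler(DataLoader(dataset, batch_size=10)))  # False -> allowed
custom = BatchSampler(SequentialSampler(dataset), batch_size=10, drop_last=False)
print(has_custom_batch_sampler(DataLoader(dataset, batch_sampler=custom)))  # True -> rejected
```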
6 changes: 5 additions & 1 deletion tests/tests_pytorch/accelerators/test_ipu.py
@@ -619,7 +619,11 @@ def test_poptorch_models_at_different_stages(tmpdir):
trainer.optimizers = model.configure_optimizers()[0]
trainer.state.fn = TrainerFn.FITTING
trainer.strategy.setup(trainer)
assert list(trainer.strategy.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]
assert list(trainer.strategy.poptorch_models) == [
RunningStage.TRAINING,
RunningStage.VALIDATING,
RunningStage.SANITY_CHECKING,
]

for fn, stage in (
(TrainerFn.VALIDATING, RunningStage.VALIDATING),
20 changes: 19 additions & 1 deletion tests/tests_pytorch/utilities/test_data.py
@@ -3,12 +3,13 @@
import pytest
import torch
from torch import Tensor
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import (
_dataloader_init_kwargs_resolve_sampler,
_get_dataloader_init_args_and_kwargs,
_replace_dataloader_init_method,
_update_dataloader,
@@ -331,6 +332,23 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c
assert getattr(dataloader, key) == value


def test_dataloader_disallow_batch_sampler():
dataset = RandomDataset(5, 100)
dataloader = DataLoader(dataset, batch_size=10)

# This should not raise
_dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)

dataset = RandomDataset(5, 100)
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size=10, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

# this should raise - the batch sampler was not automatically instantiated by the DataLoader
with pytest.raises(MisconfigurationException, match="when running on multiple IPU devices"):
_dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)


@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
"""Test that DataLoader kwargs are not replaced when using Iterable Dataset."""