
Commit 374fae5

tchaton, justusschock, awaelchli, pre-commit-ci[bot], and carmocca authored
[Feat] Add utilities for CombinedLoader state dict and dataloader state dict 1/n (#8364)
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 257fabd commit 374fae5

8 files changed: 603 additions & 87 deletions


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Set `Loop.restarting=False` at the end of the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
     * Save the loops state with the checkpoint (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
     * Save a checkpoint to restore the state on exception (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
+    * Added `state_dict` and `load_state_dict` utilities for `CombinedLoader` + utilities for dataloader ([#8364](https://github.com/PyTorchLightning/pytorch-lightning/pull/8364))


 - Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
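
For context, a minimal sketch of how the new `CombinedLoader` utilities introduced by this commit are meant to be called, based on the diffs below. The toy datasets and loader names are illustrative only, and the returned state is only populated when fault-tolerant training is enabled (see the `_fault_tolerant_enabled()` guard in `state_dict` further down):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.trainer.supporters import CombinedLoader

# two dataloaders combined into one; the datasets here are placeholders
loaders = {
    "a": DataLoader(TensorDataset(torch.arange(6)), batch_size=4),
    "b": DataLoader(TensorDataset(torch.arange(10)), batch_size=5),
}
combined = CombinedLoader(loaders, mode="max_size_cycle")

num_batches_processed = 0
for batch in combined:
    num_batches_processed += 1  # a real training loop would also run a step here

# capture the per-loader sampler state; returns an empty dict unless
# fault-tolerant training is enabled
state = combined.state_dict(num_batches_processed)

# on restart, hand the captured state back before iterating again
restored = CombinedLoader(loaders, mode="max_size_cycle")
restored.load_state_dict(state)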

pytorch_lightning/trainer/data_loading.py

Lines changed: 18 additions & 10 deletions
@@ -30,9 +30,11 @@
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.auto_restart import _sampler_metadata_collate
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import pl_worker_init_function
 
@@ -259,6 +261,10 @@ def reset_train_dataloader(self, model: 'pl.LightningModule') -> None:
         # add worker_init_fn for correct seeding in worker processes
         apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
 
+        # add collate_fn to collect metadata for fault tolerant training
+        if _fault_tolerant_enabled():
+            apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)
+
         # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
         self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)
 
@@ -460,9 +466,6 @@ def reset_train_val_dataloaders(self, model) -> None:
     def request_dataloader(self, model: 'pl.LightningModule', stage: str) -> DataLoader:
         """Handles downloading data in the GPU or TPU case.
 
-        Args:
-            dataloader_fx: The bound dataloader getter
-
         Returns:
             The dataloader
         """
@@ -474,11 +477,16 @@ def request_dataloader(self, model: 'pl.LightningModule', stage: str) -> DataLoa
 
     def _flatten_dl_only(self, dataloaders):
         # handles user error when they return:
-        # return dl1, dl2 vs return (dl1, dl2)
-        if isinstance(dataloaders, tuple):
-            all_dls = [isinstance(x, Iterable) for x in dataloaders]
-            all_dls = all(all_dls)
-            if all_dls:
-                dataloaders = list(dataloaders)
-
+        # `return dl1, dl2` vs `return (dl1, dl2)`
+        if isinstance(dataloaders, tuple) and all(isinstance(x, Iterable) for x in dataloaders):
+            return list(dataloaders)
         return dataloaders
+
+    @staticmethod
+    def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:
+        """
+        Wrap default collate function to retrieve ``FastForwardSampler`` state dict when fault tolerant is enabled.
+        """
+        dataloader.collate_fn = partial(
+            _sampler_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
+        )
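
The `_add_sampler_metadata_collate` helper above rebinds a dataloader's `collate_fn` in place so every collated batch can carry `FastForwardSampler` metadata alongside the user data. A simplified, standalone sketch of that wrapping pattern follows; `wrapped_collate` is a hypothetical stand-in for `_sampler_metadata_collate`, which lives in `pytorch_lightning.utilities.auto_restart` and is not reproduced here:

from functools import partial

from torch.utils.data import DataLoader


def wrapped_collate(samples, dataset=None, default_collate=None):
    # call the original collate function first ...
    batch = default_collate(samples)
    # ... the real implementation would also attach the sampler state here
    return batch


def add_metadata_collate(dataloader: DataLoader) -> None:
    # mirrors the static method above: rebind ``collate_fn`` in place,
    # keeping a reference to the original collate function
    dataloader.collate_fn = partial(
        wrapped_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
    )


loader = DataLoader(range(10), batch_size=4)
add_metadata_collate(loader)
next(iter(loader))  # batches still collate exactly as before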

pytorch_lightning/trainer/supporters.py

Lines changed: 121 additions & 23 deletions
@@ -14,18 +14,25 @@
 
 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any, Callable, Generator, Optional, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.utils.data import Dataset
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
 
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
+from pytorch_lightning.utilities.auto_restart import (
+    _cycle_to_next_worker_and_reset,
+    _find_current_worker,
+    CaptureIterableDataset,
+)
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 
 
 class TensorRunningAccum(object):
@@ -172,12 +179,10 @@ class CycleIterator(object):
 
     def __init__(self, loader: Any, length: Optional[int] = None):
         """
-
         Args:
             loader: the loader to restart for cyclic (and optionally infinite) sampling
             length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration
                 if None: infinite
-
         """
         if length is None:
             length = float('inf')
@@ -193,7 +198,6 @@ def __iter__(self) -> Any:
 
         Returns:
             CycleIterator: self
-
         """
         self.counter = 0
         self._loader_iter = iter(self.loader)
@@ -209,7 +213,6 @@ def __next__(self) -> Any:
 
         Raises:
             StopIteration: if more then :attr:`length` batches have been returned
-
         """
         # Note: if self.length is `inf`, then the iterator will never stop
         if self.counter >= self.__len__():
@@ -237,13 +240,11 @@ class CombinedDataset(object):
 
     def __init__(self, datasets: Union[Sequence, Mapping], mode: str = 'min_size'):
         """
-
         Args:
             datasets: a sequence/mapping datasets. Can be a collections of torch.utils.Dataset,
                 Iterable or even None.
            mode: whether to use the minimum number of batches in all samples or the maximum
                 number of batches in all samples.
-
         """
         self.datasets = datasets
         if mode not in self.COMPUTE_FUNCS.keys():
@@ -273,7 +274,6 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
 
         Returns:
             length: the length of `CombinedDataset`
-
         """
         if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
             raise MisconfigurationException(f"Invalid Mode: {mode}")
@@ -319,10 +319,14 @@ def __len__(self) -> int:
         return self._calc_num_data(self.datasets, self.mode)
 
 
+class DataLoaderDict(Dict):
+    # behaves exactly like a dict, this is used to simplify apply_to_collection.
+    pass
+
+
 class CombinedLoader(object):
     """
     Combines different dataloaders and allows sampling in parallel.
-
     Supported modes are 'min_size', which raises StopIteration after the shortest loader
     (the one with the lowest number of batches) is done, and 'max_size_cycle` which raises
     StopIteration after the longest loader (the one with most batches) is done, while cycling
@@ -342,18 +346,15 @@ class CombinedLoader(object):
         ...     print(item)
         {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
         {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
-
     """
     SUPPORTED_MODES = ('min_size', 'max_size_cycle')
 
     def __init__(self, loaders: Any, mode: str = 'min_size'):
         """
-
         Args:
             loaders: the loaders to sample from. Can be all kind of collection
             mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and
                 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.
-
         """
         if mode not in self.SUPPORTED_MODES:
             raise MisconfigurationException(f"Invalid Mode: {mode}")
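
The `DataLoaderDict` subclass introduced above (in the `@@ -319,10 +319,14 @@` hunk) is, per its own comment, only there to simplify the `apply_to_collection` machinery: those utilities dispatch on `dtype` before recursing into plain mappings, so giving the per-dataloader state its own type lets a whole state dict be handed to the callback as a single leaf instead of being descended into. A simplified sketch of that dispatch idea, not the real utility (which handles arbitrary nesting and extra arguments):

from typing import Any, Callable


class DataLoaderDict(dict):
    # same trick as in the diff: an ordinary dict with its own type
    pass


def apply(data: Any, dtype, fn: Callable) -> Any:
    if isinstance(data, dtype):  # leaf match wins ...
        return fn(data)
    if isinstance(data, dict):  # ... before mapping recursion
        return {k: apply(v, dtype, fn) for k, v in data.items()}
    return data


state = {"a": DataLoaderDict(num_batches=3), "b": DataLoaderDict(num_batches=5)}
# each DataLoaderDict is handed to ``fn`` whole instead of being recursed into
print(apply(state, DataLoaderDict, lambda d: d["num_batches"]))
# {'a': 3, 'b': 5}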
@@ -371,6 +372,84 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
         if self.mode == 'max_size_cycle':
             self._wrap_loaders_max_size_cycle()
 
+        self._loaders_iter_state_dict = None
+        self._iterator = None  # assigned in __iter__
+
+    @staticmethod
+    def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
+        # find next worker if multiple workers were used
+        state = _find_current_worker(iterator)
+        if isinstance(dataloader.dataset, CaptureIterableDataset):
+            # the sampler state dicts are extracted in `CombinedLoaderIterator`
+            if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
+                state.update(iterator._sampler_state_dict[0])
+        else:
+            # fetch directly from fast forward sampler
+            state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
+        return DataLoaderDict(state)
+
+    def state_dict(self, num_batches_processed: int) -> Dict:
+        """
+        The state dict includes all states from wrapped dataloaders and their samplers through the
+        ``CaptureIterableDataset`` and fast-forward samplers.
+
+        Args:
+            num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
+                may have already prefetched more batches by the time a state dict is requested.
+        """
+        if not _fault_tolerant_enabled():
+            return DataLoaderDict()
+
+        state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
+
+        return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
+
+    def load_state_dict(self, state_dict):
+        # store the samplers state.
+        # They will be reloaded once the ``CombinedIterator`` has been created
+        # and the workers are created.
+        self._loaders_iter_state_dict = state_dict
+
+        def mock_reset_fn(self, *_, **__):
+            pass
+
+        # mock reset call, so we can rotate the ``_worker_queue_idx_cycle`` to failed worker
+        # and get the first batch from it
+        _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
+        _MultiProcessingDataLoaderIter._reset = mock_reset_fn
+
+    def on_restart(self, iterator: Iterator):
+        if not self._loaders_iter_state_dict:
+            return
+
+        # this happens inside the workers if any were specified.
+
+        def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
+            if isinstance(dataloader.dataset, CaptureIterableDataset):
+                # provide the ``state_dict`` to the ``CaptureIterableDataset``
+                # as it is responsible for passing down the state to associated ``FastForwardSampler``
+                dataloader.dataset.load_state_dict(state_dict)
+            else:
+                # for ``Mapping-based`` dataset, the ``fast_forward_sampler`` was attached
+                # on the dataloader for simplicity
+                dataloader.fast_forward_sampler.load_state_dict(state_dict)
+
+            # cycle back the iterator to the failed worker if multiple workers were provided
+            iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
+
+            if isinstance(dataloader.dataset, CaptureIterableDataset):
+                # remove keys related to iterator
+                state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
+                # need to re-attach the state dict into the iterator for future collection.
+                iterator._sampler_state_dict = [state_dict]
+            return iterator
+
+        # apply the ``create_loader_iters`` on the collection of ``DataLoader / Iterator``.
+        # each ``Iterator`` was created from the ``DataLoader``.
+        iterator._loader_iters = apply_to_collections(
+            self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
+        )
+
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
         """Return a collections of samplers extracting from loaders."""
@@ -382,7 +461,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
 
         Returns:
             the wrapped loaders
-
         """
         all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
 
@@ -398,7 +476,18 @@ def __iter__(self) -> Any:
         """
         Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.
         """
-        return CombinedLoaderIterator(self.loaders)
+
+        # prevent ``NotImplementedError`` from PyTorch:
+        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
+        def __getstate__patch__(*_):
+            return {}
+
+        _BaseDataLoaderIter.__getstate__ = __getstate__patch__
+        iterator = CombinedLoaderIterator(self.loaders)
+        # handle fault tolerant restart logic.
+        self.on_restart(iterator)
+        self._iterator = iterator
+        return iterator
 
     @staticmethod
     def _calc_num_batches(loaders: Any) -> Union[int, float]:
@@ -410,7 +499,6 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
 
         Returns:
             length: the minimum length of loaders
-
         """
         all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
 
@@ -429,10 +517,8 @@ class CombinedLoaderIterator(object):
 
     def __init__(self, loaders: Any):
         """
-
         Args:
             loaders: the loaders to sample from. Can be all kind of collection
-
         """
         self.loaders = loaders
         self._loader_iters = None
@@ -456,7 +542,6 @@ def __next__(self) -> Any:
 
         Returns:
             a collections of batch data
-
         """
         return self.request_next_batch(self.loader_iters)
 
@@ -470,9 +555,23 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any:
 
         Returns
             Any: a collections of batch data
-
         """
-        return apply_to_collection(loader_iters, Iterator, next)
+
+        def next_fn(iterator: Iterator):
+            batch = next(iterator)
+            if not _fault_tolerant_enabled():
+                return batch
+            # when fault tolerant is enabled, the iterator will return
+            # ``FastForwardSampler`` state_dict metadata
+            # alongside the user data.
+            # the metadata are extracted and stored directly on the iterator
+            # to simplify the collection on ``state_dict`` call.
+            batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
+            # store the ``sampler_state_dict`` on the iterator
+            CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
+            return batch
+
+        return apply_to_collection(loader_iters, Iterator, next_fn)
 
     @staticmethod
     def create_loader_iters(
@@ -486,7 +585,6 @@ def create_loader_iters(
 
         Returns
             a collections of iterators
-
         """
         # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences
         return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))
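
The `next_fn` wrapper added to `request_next_batch` above is what makes the per-worker sampler state visible to `state_dict`: with fault-tolerant training enabled, `CaptureIterableDataset` injects `FastForwardSampler` state into each yielded batch, `next_fn` strips it back out with `extract_samplers_state_dict_from_batch`, and the stripped state is stashed on the iterator for the next `state_dict` call. A toy illustration of that split follows; the batch layout used here (a dict with a reserved metadata key) is an assumption for illustration only, as the real format is owned by `CaptureIterableDataset`:

from typing import Any, Dict, Iterator, Tuple

METADATA_KEY = "__sampler_state__"  # hypothetical key, for illustration only


def split_batch(batch: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict]:
    # separate the user-visible batch from the injected sampler metadata
    sampler_state = batch.pop(METADATA_KEY, {})
    return batch, sampler_state


def next_with_metadata(iterator: Iterator, metadata_store: list) -> Any:
    batch, sampler_state = split_batch(next(iterator))
    metadata_store.append(sampler_state)  # later consumed by a ``state_dict`` call
    return batch


store: list = []
batches = iter([{"x": [1, 2], METADATA_KEY: {"current_iteration": 2}}])
print(next_with_metadata(batches, store))  # {'x': [1, 2]}
print(store)                               # [{'current_iteration': 2}]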
