
Commit b13749b

Authored by: awaelchli, tchaton, kaushikb11, carmocca, Borda
add fault-tolerance for global random state in map-style datasets (#8950)
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 0752bcd commit b13749b

7 files changed (+402 −72 lines)


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
     * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
     * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
+    * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
 
 - Checkpoint saving & loading extensibility:
     * Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))

pytorch_lightning/loops/fit_loop.py

Lines changed: 17 additions & 1 deletion
@@ -14,7 +14,7 @@
 
 import logging
 from contextlib import suppress
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
@@ -40,6 +40,8 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] =
         self.min_epochs = min_epochs
         self.epoch_loop: Optional[TrainingEpochLoop] = None
         self.epoch_progress = Progress()
+        # caches the loaded dataloader state until dataloader objects are available
+        self._dataloader_state_dict: Dict[str, Any] = {}
 
     @property
     def current_epoch(self) -> int:
@@ -175,6 +177,10 @@ def on_advance_start(self) -> None:
         if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
             self.trainer.reset_train_dataloader(model)
 
+        if self._dataloader_state_dict:
+            self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict)
+            self._dataloader_state_dict = {}
+
         # TODO: specify the possible exception
         with suppress(Exception):
             # set seed for distributed sampler (enables shuffling for each epoch)
@@ -234,3 +240,13 @@ def should_accumulate(self) -> bool:
 
     def teardown(self) -> None:
         self.epoch_loop.teardown()
+
+    def on_save_checkpoint(self) -> Dict:
+        state_dict = super().on_save_checkpoint()
+        # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
+        state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
+        return state_dict
+
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        # cache the dataloader state dict until the dataloader objects are available
+        self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
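For orientation, here is a minimal, self-contained sketch of the caching pattern this hunk introduces: the dataloader state cannot be applied at checkpoint-load time because the dataloaders are only built later, so the loop caches the dict and applies it once at the next epoch start. `ToyLoader` and `ToyFitLoop` are hypothetical stand-ins for illustration, not Lightning classes.

from typing import Any, Dict, Optional


class ToyLoader:
    """Hypothetical stand-in for the train dataloader wrapper; only load/state_dict matter here."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.state = dict(state)

    def state_dict(self, has_completed: bool = False) -> Dict[str, Any]:
        return dict(self.state)


class ToyFitLoop:
    """Hypothetical stand-in showing the cache-then-apply flow of the FitLoop hunks above."""

    def __init__(self) -> None:
        self.train_dataloader: Optional[ToyLoader] = None
        # caches the loaded dataloader state until dataloader objects are available
        self._dataloader_state_dict: Dict[str, Any] = {}

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # the dataloader does not exist yet, so only cache the state here
        self._dataloader_state_dict = checkpoint.get("dataloader_state_dict", {})

    def on_advance_start(self) -> None:
        # the dataloader is (re)built at epoch start; apply the cached state exactly once
        self.train_dataloader = ToyLoader()
        if self._dataloader_state_dict:
            self.train_dataloader.load_state_dict(self._dataloader_state_dict)
            self._dataloader_state_dict = {}


loop = ToyFitLoop()
loop.on_load_checkpoint({"dataloader_state_dict": {"latest_worker_id": 0}})
loop.on_advance_start()
assert loop.train_dataloader.state == {"latest_worker_id": 0}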

pytorch_lightning/trainer/supporters.py

Lines changed: 93 additions & 58 deletions
@@ -13,20 +13,23 @@
 # limitations under the License.
 
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from torch.utils.data import Dataset
-from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
 
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
 from pytorch_lightning.utilities.auto_restart import (
-    _cycle_to_next_worker_and_reset,
-    _find_current_worker,
+    _find_fast_forward_samplers,
     CaptureIterableDataset,
+    CaptureMapDataset,
+    IteratorState,
+    MergedIteratorState,
+    patch_dataloader_iterator,
 )
 from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -167,6 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle
         self.loader = loader
         self._loader_iter = None
         self.counter = 0
+        self.state = state
 
     def __iter__(self) -> Any:
         """
@@ -176,6 +180,7 @@ def __iter__(self) -> Any:
             CycleIterator: self
         """
         self.counter = 0
+        self.state.reset()
         self._loader_iter = iter(self.loader)
         return self
 
@@ -205,6 +210,12 @@ def __next__(self) -> Any:
                 raise StopIteration
 
             self._loader_iter = iter(self.loader)
+            # if fault tolerant is enabled, we need to patch the iterator to collect the states
+            # before the batch gets returned.
+            fetcher = getattr(self.loader, "_lightning_fetcher", None)
+            if fetcher:
+                patch_dataloader_iterator(self.loader, self._loader_iter, fetcher)
+
             return next(self._loader_iter)
 
         finally:
@@ -302,11 +313,6 @@ def __len__(self) -> int:
         return self._calc_num_data(self.datasets, self.mode)
 
 
-class DataLoaderDict(Dict):
-    # behaves exactly like a dict, this is used to simplify apply_to_collection.
-    pass
-
-
 class CombinedLoader:
     """
     Combines different dataloaders and allows sampling in parallel.
@@ -360,80 +366,110 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
         self._iterator = None  # assigned in __iter__
 
     @staticmethod
-    def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
-        # find next worker if multiple workers were used
-        state = _find_current_worker(iterator)
-        if isinstance(dataloader.dataset, CaptureIterableDataset):
-            # the sampler state dict are extracted in `CombinedLoaderIterator`
-            if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
-                state.update(iterator._sampler_state_dict[0])
-        else:
-            # fetch directly from fast forward sampler
-            state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
-        return DataLoaderDict(state)
-
-    def state_dict(self, num_batches_processed: int) -> Dict:
+    def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict:
+        if isinstance(dataloader, CycleIterator):
+            iterator = dataloader._loader_iter
+        state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None)
+        if state:
+            return asdict(state)
+        return {}
+
+    def state_dict(self, has_completed: bool = False) -> Dict:
         """
         The state dict includes all states from wrapped dataloaders and their samplers through the
         ``CaptureIterableDataset`` and fast-forward samplers.
 
         Args:
-            num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
-                may have already prefetched more batches by the time a state dict is requested.
+            has_completed: whether the current state of data fetching is considered completed or not. If it is, the
+                current state gets returned, otherwise the previously cached state.
        """
-        if not _fault_tolerant_training():
-            return DataLoaderDict()
-
-        state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
+        if not _fault_tolerant_training() or self._iterator is None:
+            return {}
 
-        return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
+        return apply_to_collections(
+            self.loaders,
+            self._iterator.loader_iters,
+            (Iterator, DataLoader),
+            partial(self._state_dict_fn, has_completed=has_completed),
+        )
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict) -> None:
         # store the samplers state.
         # They would be reloaded once the `CombinedIterator` as been created
         # and the workers are created.
         self._loaders_iter_state_dict = state_dict
 
-        def mock_reset_fn(self, *_, **__):
-            pass
-
-        # mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker
-        # and get the first batch from it
-        _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
-        _MultiProcessingDataLoaderIter._reset = mock_reset_fn
-
-    def on_restart(self, iterator: Iterator):
+    def on_restart(self, iterator: Iterator) -> None:
        if not self._loaders_iter_state_dict:
            return
 
-        # this happen inside the workers if any were specificied.
+        def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
+            """Function used to reload the iterator state before once the workers are created."""
+
+            dataloader_to_iter_on = dataloader
+            if isinstance(dataloader, CycleIterator):
+                dataloader = dataloader_to_iter_on.loader
+
+            dataset = dataloader.dataset
+
+            # We reload the states before creating the workers
+            # The specific type of dataset will then decide if the state should be applied before or after
+            # spawning the workers
+            if isinstance(dataset, CaptureMapDataset):
+                iterator_state = state_dict["state"][0]
+
+                if not isinstance(iterator_state, IteratorState):
+                    iterator_state = IteratorState.from_state_dict(iterator_state)
+
+                # reload sampler state
+                ff_sampler = _find_fast_forward_samplers(dataloader)
+                ff_sampler.load_state_dict(iterator_state.sampler_state)
+                # reload dataset state
+                dataset.load_state_dict(
+                    iterator_state.dataset_state,
+                    latest_worker_id=state_dict["latest_worker_id"],
+                    num_workers=iterator_state.num_workers,
+                )
+
+            elif isinstance(dataset, CaptureIterableDataset):
+                dataset_dict = {
+                    sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
+                }
+                dataset.load_state_dict(dataset_dict)
 
-        def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
-            if isinstance(dataloader.dataset, CaptureIterableDataset):
-                # provide the `state_dict` to the `CaptureIterableDataset`
-                # as it is responsible for passing down the state to associated `FastForwardSampler`
-                dataloader.dataset.load_state_dict(state_dict)
             else:
-                # for `Mapping-based` dataset, the `fast_forward_sampler` was attached
-                # on the dataloader for simplicity
-                dataloader.fast_forward_sampler.load_state_dict(state_dict)
+                raise MisconfigurationException(
+                    "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
+                )
+
+            # We finally spawned the workers if any.
+            it = iter(dataloader_to_iter_on)
 
-            # cycle back the iterator to the failed worker if multiple workers were provided
-            iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
+            # restore caching state
+            state = MergedIteratorState.from_state_dict(state_dict)
 
-            if isinstance(dataloader.dataset, CaptureIterableDataset):
-                # remove keys related to iterator
-                state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
-                # need to re-attach the state dict into the iterator for future collection.
-                iterator._sampler_state_dict = [state_dict]
-            return iterator
+            if isinstance(dataloader_to_iter_on, CycleIterator):
+                it._loader_iter.state = state
+            else:
+                it.state = state
+            return it
+
+        # create an un-existing token, so it doesn't activate for something else than an iterator.
+        class DataLoaderDict(dict):
+            pass
 
         # apply the `create_loader_iters` on the collection of `DataLoader / Iterator`.
         # each `Iterator` was created from the `DataLoader`.
         iterator._loader_iters = apply_to_collections(
-            self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
+            self.loaders,
+            self._loaders_iter_state_dict,
+            (Iterable, DataLoaderDict),
+            create_loader_iters,
+            wrong_dtype=(Sequence, Mapping),
        )
 
+        self._loaders_iter_state_dict = None
+
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
         """Return a collections of samplers extracting from loaders."""
@@ -457,7 +493,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
         self.loaders = apply_to_collection(
             self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
         )
-
         state.reset()
 
     def __iter__(self) -> Any:
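To illustrate the pattern behind `_state_dict_fn` and `state_dict(has_completed=...)` above: each wrapped iterator keeps the last confirmed state alongside the most recent one (which may already include a prefetched batch), and the requested one is serialized with `dataclasses.asdict`. The sketch below is illustrative only; `ToyState` and `ToyIterator` are made-up stand-ins, not Lightning APIs.

from dataclasses import asdict, dataclass
from typing import Any, Dict, Iterator, List


@dataclass
class ToyState:
    num_batches_fetched: int = 0


class ToyIterator:
    def __init__(self, data: List[int]) -> None:
        self._inner: Iterator[int] = iter(data)
        self.previous_state = ToyState()  # state before the batch currently in flight
        self.state = ToyState()           # state including the most recently fetched batch

    def __next__(self) -> int:
        batch = next(self._inner)
        self.previous_state = self.state
        self.state = ToyState(self.state.num_batches_fetched + 1)
        return batch

    def state_dict(self, has_completed: bool = False) -> Dict[str, Any]:
        # if the in-flight batch has not been consumed yet, return the previously cached state
        return asdict(self.state if has_completed else self.previous_state)


it = ToyIterator([10, 20, 30])
next(it)
next(it)
assert it.state_dict(has_completed=True) == {"num_batches_fetched": 2}
assert it.state_dict(has_completed=False) == {"num_batches_fetched": 1}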

pytorch_lightning/utilities/auto_restart.py

Lines changed: 34 additions & 4 deletions
@@ -16,8 +16,12 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import partial, wraps
+from random import getstate as python_get_rng_state
+from random import setstate as python_set_rng_state
 from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
 
+import numpy as np
+import torch
 from torch.utils.data import Dataset, get_worker_info, Sampler
 from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
 
@@ -168,6 +172,16 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> Non
         state[latest_worker_id] = new_state
         self.latest_worker_id = latest_worker_id
 
+    @property
+    def sampler_states(self) -> Dict[int, Any]:
+        """Returns the merged sampler states for all worker processes."""
+        return {0: self.state[k].sampler_state[0] for k in self.state.keys()}
+
+    @property
+    def dataset_states(self) -> Dict[int, Any]:
+        """Returns the merged dataset states for all worker processes."""
+        return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
+
     @classmethod
     def from_state_dict(cls, state_dict) -> "MergedIteratorState":
         if state_dict["represent_map_dataset"]:
@@ -188,7 +202,12 @@ def __len__(self) -> int:
 
 
 class CaptureMapDataset(Dataset):
-    """This class is used to capture the state from the map-based state dataset."""
+    """This class is used to capture the state from the map-based state dataset.
+
+    Note:
+        We currently don't support restoring if we fail during the first `N = num_workers` batches, where
+        `num_workers` is the number of workers spawned by the dataloader.
+    """
 
     def __init__(self, dataset: Dataset) -> None:
         self.dataset = dataset
@@ -202,8 +221,7 @@ def worker_id(self) -> int:
     def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
         if self._cached_state_dict is not None:
             if self.worker_id in self._cached_state_dict:
-                # TODO: reset random states
-                pass
+                set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
             self._cached_state_dict = None
 
         data = self.dataset[item]
@@ -227,7 +245,19 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num
         self._cached_state_dict = state_dict
 
     def _state_dict(self) -> Dict[int, Dict[str, Any]]:
-        return {self.worker_id: {"rng_states": {}}}
+        return {self.worker_id: {"rng_states": collect_rng_states()}}
+
+
+def collect_rng_states() -> Dict[str, Any]:
+    """Collect the global random state of :mod:`torch`, :mod:`numpy` and Python."""
+    return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()}
+
+
+def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
+    """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process."""
+    torch.set_rng_state(rng_state_dict.get("torch"))
+    np.random.set_state(rng_state_dict.get("numpy"))
+    python_set_rng_state(rng_state_dict.get("python"))
 
 
 class CaptureIterableDataset(IterableDataset):