 
 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Any, Callable, Generator, Optional, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.utils.data import Dataset
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
 
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
+from pytorch_lightning.utilities.auto_restart import (
+    _cycle_to_next_worker_and_reset,
+    _find_current_worker,
+    CaptureIterableDataset,
+)
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 
 
 class TensorRunningAccum(object):
@@ -172,12 +179,10 @@ class CycleIterator(object):
 
     def __init__(self, loader: Any, length: Optional[int] = None):
         """
-
         Args:
             loader: the loader to restart for cyclic (and optionally infinite) sampling
             length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration
                 if None: infinite
-
         """
         if length is None:
             length = float('inf')
@@ -193,7 +198,6 @@ def __iter__(self) -> Any:
 
         Returns:
             CycleIterator: self
-
         """
         self.counter = 0
         self._loader_iter = iter(self.loader)
@@ -209,7 +213,6 @@ def __next__(self) -> Any:
 
         Raises:
             StopIteration: if more then :attr:`length` batches have been returned
-
         """
         # Note: if self.length is `inf`, then the iterator will never stop
         if self.counter >= self.__len__():
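# A minimal usage sketch (illustrative, not part of this diff): ``CycleIterator`` restarts
# the wrapped loader until ``length`` batches have been produced. Assumes a small
# in-memory DataLoader from the imports above.
cycle = CycleIterator(DataLoader(range(4), batch_size=2), length=5)
batches = list(cycle)  # 2 batches per pass over the loader, restarted until 5 batches total
assert len(batches) == 5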
@@ -237,13 +240,11 @@ class CombinedDataset(object):
 
     def __init__(self, datasets: Union[Sequence, Mapping], mode: str = 'min_size'):
         """
-
         Args:
             datasets: a sequence/mapping datasets. Can be a collections of torch.utils.Dataset,
                 Iterable or even None.
             mode: whether to use the minimum number of batches in all samples or the maximum
                 number of batches in all samples.
-
         """
         self.datasets = datasets
         if mode not in self.COMPUTE_FUNCS.keys():
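# A minimal sketch of the two modes (illustrative, not part of this diff; assumes
# ``COMPUTE_FUNCS`` maps 'min_size' to ``min`` and 'max_size_cycle' to ``max``, as the
# docstring above describes):
short_long = {'a': range(10), 'b': range(5)}
len(CombinedDataset(short_long, mode='min_size'))        # -> 5, bounded by the shortest dataset
len(CombinedDataset(short_long, mode='max_size_cycle'))  # -> 10, bounded by the longest dataset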
@@ -273,7 +274,6 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
 
         Returns:
             length: the length of `CombinedDataset`
-
         """
         if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
             raise MisconfigurationException(f"Invalid Mode: {mode}")
@@ -319,10 +319,14 @@ def __len__(self) -> int:
         return self._calc_num_data(self.datasets, self.mode)
 
 
+class DataLoaderDict(Dict):
+    # Behaves exactly like a dict; exists only as a distinct type to simplify ``apply_to_collection``.
+    pass
+
+
 class CombinedLoader(object):
     """
     Combines different dataloaders and allows sampling in parallel.
-
     Supported modes are 'min_size', which raises StopIteration after the shortest loader
     (the one with the lowest number of batches) is done, and 'max_size_cycle` which raises
     StopIteration after the longest loader (the one with most batches) is done, while cycling
@@ -342,18 +346,15 @@ class CombinedLoader(object):
         ...     print(item)
         {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
         {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
-
     """
     SUPPORTED_MODES = ('min_size', 'max_size_cycle')
 
     def __init__(self, loaders: Any, mode: str = 'min_size'):
         """
-
         Args:
             loaders: the loaders to sample from. Can be all kind of collection
             mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and
                 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.
-
         """
         if mode not in self.SUPPORTED_MODES:
             raise MisconfigurationException(f"Invalid Mode: {mode}")
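# A usage sketch of the two modes (illustrative, not part of this diff; assumes
# ``CombinedLoader`` exposes ``__len__`` through ``_calc_num_batches``; mirrors the doctest above):
loaders = {'a': DataLoader(range(6), batch_size=4), 'b': DataLoader(range(15), batch_size=5)}
len(CombinedLoader(loaders, 'min_size'))        # -> 2, stops with the shortest loader
len(CombinedLoader(loaders, 'max_size_cycle'))  # -> 3, the shorter loader 'a' is cycled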
@@ -371,6 +372,84 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
         if self.mode == 'max_size_cycle':
             self._wrap_loaders_max_size_cycle()
 
+        self._loaders_iter_state_dict = None
+        self._iterator = None  # assigned in __iter__
+
+    @staticmethod
+    def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
+        # find the next worker if multiple workers were used
+        state = _find_current_worker(iterator)
+        if isinstance(dataloader.dataset, CaptureIterableDataset):
+            # the sampler state dicts are extracted in `CombinedLoaderIterator`
+            if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
+                state.update(iterator._sampler_state_dict[0])
+        else:
+            # fetch directly from the fast-forward sampler
+            state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
+        return DataLoaderDict(state)
+
+    def state_dict(self, num_batches_processed: int) -> Dict:
+        """
+        The state dict includes all states from wrapped dataloaders and their samplers through the
+        ``CaptureIterableDataset`` and fast-forward samplers.
+
+        Args:
+            num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
+                may have already prefetched more batches by the time a state dict is requested.
+        """
+        if not _fault_tolerant_enabled():
+            return DataLoaderDict()
+
+        state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
+
+        return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
+
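# A sketch of how a caller might snapshot the combined loader (hypothetical usage, not part
# of this diff; assumes fault tolerance is enabled and ``loaders`` were prepared with
# fast-forward samplers / ``CaptureIterableDataset`` by the fault-tolerant setup elsewhere):
combined = CombinedLoader(loaders, 'max_size_cycle')
combined_iter = iter(combined)          # also assigns ``combined._iterator`` used above
first_batch = next(combined_iter)
loader_state = combined.state_dict(num_batches_processed=1)
# ``loader_state`` mirrors the structure of ``loaders`` with one ``DataLoaderDict`` per dataloader.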
+    def load_state_dict(self, state_dict):
+        # Store the sampler states.
+        # They will be reloaded once the ``CombinedLoaderIterator`` has been created
+        # and the workers have been spawned.
+        self._loaders_iter_state_dict = state_dict
+
+        def mock_reset_fn(self, *_, **__):
+            pass
+
+        # Mock the reset call, so we can rotate the ``_worker_queue_idx_cycle`` to the failed worker
+        # and get the first batch from it.
+        _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
+        _MultiProcessingDataLoaderIter._reset = mock_reset_fn
+
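# A sketch of the expected restore flow (hypothetical, not part of this diff;
# ``checkpoint["loader_state"]`` is a placeholder for a state previously produced by
# ``state_dict`` above):
combined = CombinedLoader(loaders, 'max_size_cycle')
combined.load_state_dict(checkpoint["loader_state"])  # only stores the state for now
combined_iter = iter(combined)                        # __iter__ calls ``on_restart`` (below) to fast-forward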
+    def on_restart(self, iterator: Iterator):
+        if not self._loaders_iter_state_dict:
+            return
+
+        # this happens inside the workers if any were specified.
+
+        def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
+            if isinstance(dataloader.dataset, CaptureIterableDataset):
+                # provide the ``state_dict`` to the ``CaptureIterableDataset``
+                # as it is responsible for passing down the state to the associated ``FastForwardSampler``
+                dataloader.dataset.load_state_dict(state_dict)
+            else:
+                # for ``Mapping``-based datasets, the ``fast_forward_sampler`` was attached
+                # to the dataloader for simplicity
+                dataloader.fast_forward_sampler.load_state_dict(state_dict)
+
+            # cycle back the iterator to the failed worker if multiple workers were provided
+            iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
+
+            if isinstance(dataloader.dataset, CaptureIterableDataset):
+                # remove keys related to the iterator
+                state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
+                # re-attach the state dict to the iterator for future collection.
+                iterator._sampler_state_dict = [state_dict]
+            return iterator
+
+        # Apply ``create_loader_iters`` over the paired collections of ``DataLoader``s and ``DataLoaderDict``s.
+        # Each resulting ``Iterator`` is created from its corresponding ``DataLoader``.
+        iterator._loader_iters = apply_to_collections(
+            self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
+        )
+
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
         """Return a collections of samplers extracting from loaders."""
@@ -382,7 +461,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
 
         Returns:
             the wrapped loaders
-
         """
         all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
 
@@ -398,7 +476,18 @@ def __iter__(self) -> Any:
         """
         Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.
         """
-        return CombinedLoaderIterator(self.loaders)
+
+        # prevent ``NotImplementedError`` from PyTorch:
+        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
+        def __getstate__patch__(*_):
+            return {}
+
+        _BaseDataLoaderIter.__getstate__ = __getstate__patch__
+        iterator = CombinedLoaderIterator(self.loaders)
+        # handle the fault-tolerant restart logic.
+        self.on_restart(iterator)
+        self._iterator = iterator
+        return iterator
 
     @staticmethod
     def _calc_num_batches(loaders: Any) -> Union[int, float]:
@@ -410,7 +499,6 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:
 
         Returns:
             length: the minimum length of loaders
-
         """
         all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
 
@@ -429,10 +517,8 @@ class CombinedLoaderIterator(object):
 
     def __init__(self, loaders: Any):
         """
-
         Args:
             loaders: the loaders to sample from. Can be all kind of collection
-
         """
         self.loaders = loaders
         self._loader_iters = None
@@ -456,7 +542,6 @@ def __next__(self) -> Any:
 
         Returns:
             a collections of batch data
-
         """
         return self.request_next_batch(self.loader_iters)
 
@@ -470,9 +555,23 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any:
 
         Returns
             Any: a collections of batch data
-
         """
-        return apply_to_collection(loader_iters, Iterator, next)
+
+        def next_fn(iterator: Iterator):
+            batch = next(iterator)
+            if not _fault_tolerant_enabled():
+                return batch
+            # When fault tolerance is enabled, the iterator returns the
+            # ``FastForwardSampler`` state_dict metadata alongside the user data.
+            # The metadata is extracted and stored directly on the iterator
+            # to simplify its collection on a ``state_dict`` call.
+            batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
+            # store the ``sampler_state_dict`` on the iterator
+            CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
+            return batch
+
+        return apply_to_collection(loader_iters, Iterator, next_fn)
 
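# A minimal sketch of ``request_next_batch`` with fault tolerance disabled (illustrative,
# not part of this diff): every iterator in the collection is advanced once and the
# batches come back in the same structure.
iters = {
    'a': iter(DataLoader(range(4), batch_size=2)),
    'b': iter(DataLoader(range(6), batch_size=3)),
}
batch = CombinedLoaderIterator.request_next_batch(iters)
# batch == {'a': tensor([0, 1]), 'b': tensor([0, 1, 2])}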
     @staticmethod
     def create_loader_iters(
@@ -486,7 +585,6 @@ def create_loader_iters(
 
         Returns
             a collections of iterators
-
         """
         # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences
         return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))
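# A sketch of ``create_loader_iters`` (illustrative, not part of this diff): the collection
# shape is preserved and only the leaf dataloaders are turned into iterators; dicts and
# sequences themselves are recursed into rather than passed to ``iter``.
loaders = {'a': DataLoader(range(4), batch_size=2), 'b': [DataLoader(range(6), batch_size=3)]}
loader_iters = CombinedLoaderIterator.create_loader_iters(loaders)
# loader_iters == {'a': <iterator over 'a'>, 'b': [<iterator over 'b'[0]>]}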