 # limitations under the License.

 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch.utils.data import Dataset
-from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset

 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
 from pytorch_lightning.utilities.auto_restart import (
-    _cycle_to_next_worker_and_reset,
-    _find_current_worker,
+    _find_fast_forward_samplers,
     CaptureIterableDataset,
+    CaptureMapDataset,
+    IteratorState,
+    MergedIteratorState,
+    patch_dataloader_iterator,
 )
 from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -167,6 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle
         self.loader = loader
         self._loader_iter = None
         self.counter = 0
+        self.state = state

     def __iter__(self) -> Any:
         """
@@ -176,6 +180,7 @@ def __iter__(self) -> Any:
             CycleIterator: self
         """
         self.counter = 0
+        self.state.reset()
         self._loader_iter = iter(self.loader)
         return self

@@ -205,6 +210,12 @@ def __next__(self) -> Any:
                 raise StopIteration

             self._loader_iter = iter(self.loader)
+            # if fault tolerant is enabled, we need to patch the iterator to collect the states
+            # before the batch gets returned.
+            fetcher = getattr(self.loader, "_lightning_fetcher", None)
+            if fetcher:
+                patch_dataloader_iterator(self.loader, self._loader_iter, fetcher)
+
             return next(self._loader_iter)

         finally:
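For intuition, a minimal self-contained sketch of the patching idea referenced in the comment above: wrap an iterator so that state is recorded right before each batch is handed back. The names below (`_IteratorStateCapture`, `record_state`) are illustrative stand-ins, not the actual `patch_dataloader_iterator` implementation from `pytorch_lightning.utilities.auto_restart`.

from typing import Any, Callable, Iterator


class _IteratorStateCapture:
    """Toy wrapper recording a batch counter (a stand-in for sampler/dataset state)."""

    def __init__(self, iterator: Iterator, record_state: Callable[[int], None]) -> None:
        self._iterator = iterator
        self._record_state = record_state
        self._num_batches_fetched = 0

    def __iter__(self) -> "_IteratorStateCapture":
        return self

    def __next__(self) -> Any:
        batch = next(self._iterator)
        self._num_batches_fetched += 1
        # capture the state *before* the batch is returned to the caller
        self._record_state(self._num_batches_fetched)
        return batch


states = []
wrapped = _IteratorStateCapture(iter(range(3)), states.append)
assert list(wrapped) == [0, 1, 2] and states == [1, 2, 3]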
@@ -302,11 +313,6 @@ def __len__(self) -> int:
         return self._calc_num_data(self.datasets, self.mode)


-class DataLoaderDict(Dict):
-    # behaves exactly like a dict, this is used to simplify apply_to_collection.
-    pass
-
-
 class CombinedLoader:
     """
     Combines different dataloaders and allows sampling in parallel.
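The class docstring is cut off by the hunk boundary, so here is a hedged usage sketch of the behavior it describes (the toy loaders and batch sizes are assumptions, not code from this commit): each yielded batch mirrors the structure of the input collection, and `mode` decides whether iteration stops at the shortest loader ("min_size", the default) or cycles the shorter loaders until the longest is exhausted ("max_size_cycle").

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader

loaders = {"a": DataLoader(range(6), batch_size=4), "b": DataLoader(range(15), batch_size=5)}

combined = CombinedLoader(loaders, mode="max_size_cycle")
for batch in combined:
    # each batch is a dict with the same keys as `loaders`,
    # e.g. {"a": tensor([0, 1, 2, 3]), "b": tensor([0, 1, 2, 3, 4])}
    ...

# with the default mode="min_size", the combined length is the shortest loader's length
assert len(CombinedLoader(loaders, mode="min_size")) == 2  # min(ceil(6 / 4), ceil(15 / 5))
assert len(combined) == 3                                  # max(ceil(6 / 4), ceil(15 / 5))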
@@ -360,80 +366,110 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
         self._iterator = None  # assigned in __iter__

     @staticmethod
-    def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
-        # find next worker if multiple workers were used
-        state = _find_current_worker(iterator)
-        if isinstance(dataloader.dataset, CaptureIterableDataset):
-            # the sampler state dict are extracted in `CombinedLoaderIterator`
-            if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
-                state.update(iterator._sampler_state_dict[0])
-        else:
-            # fetch directly from fast forward sampler
-            state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
-        return DataLoaderDict(state)
-
-    def state_dict(self, num_batches_processed: int) -> Dict:
+    def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict:
+        if isinstance(dataloader, CycleIterator):
+            iterator = dataloader._loader_iter
+        state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None)
+        if state:
+            return asdict(state)
+        return {}
+
+    def state_dict(self, has_completed: bool = False) -> Dict:
         """
         The state dict includes all states from wrapped dataloaders and their samplers through the
         ``CaptureIterableDataset`` and fast-forward samplers.

         Args:
-            num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
-                may have already prefetched more batches by the time a state dict is requested.
+            has_completed: whether the current state of data fetching is considered completed or not. If it is, the
+                current state gets returned, otherwise the previously cached state.
         """
-        if not _fault_tolerant_training():
-            return DataLoaderDict()
-
-        state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
+        if not _fault_tolerant_training() or self._iterator is None:
+            return {}

-        return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
+        return apply_to_collections(
+            self.loaders,
+            self._iterator.loader_iters,
+            (Iterator, DataLoader),
+            partial(self._state_dict_fn, has_completed=has_completed),
+        )

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict) -> None:
         # store the samplers state.
         # They would be reloaded once the `CombinedIterator` has been created
         # and the workers are created.
         self._loaders_iter_state_dict = state_dict

-        def mock_reset_fn(self, *_, **__):
-            pass
-
-        # mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker
-        # and get the first batch from it
-        _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
-        _MultiProcessingDataLoaderIter._reset = mock_reset_fn
-
-    def on_restart(self, iterator: Iterator):
+    def on_restart(self, iterator: Iterator) -> None:
         if not self._loaders_iter_state_dict:
             return

-        # this happen inside the workers if any were specificied.
+        def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
+            """Function used to reload the iterator state before the workers are created."""
+
+            dataloader_to_iter_on = dataloader
+            if isinstance(dataloader, CycleIterator):
+                dataloader = dataloader_to_iter_on.loader
+
+            dataset = dataloader.dataset
+
+            # We reload the states before creating the workers.
+            # The specific type of dataset will then decide if the state should be applied before or after
+            # spawning the workers.
+            if isinstance(dataset, CaptureMapDataset):
+                iterator_state = state_dict["state"][0]
+
+                if not isinstance(iterator_state, IteratorState):
+                    iterator_state = IteratorState.from_state_dict(iterator_state)
+
+                # reload sampler state
+                ff_sampler = _find_fast_forward_samplers(dataloader)
+                ff_sampler.load_state_dict(iterator_state.sampler_state)
+                # reload dataset state
+                dataset.load_state_dict(
+                    iterator_state.dataset_state,
+                    latest_worker_id=state_dict["latest_worker_id"],
+                    num_workers=iterator_state.num_workers,
+                )
+
+            elif isinstance(dataset, CaptureIterableDataset):
+                dataset_dict = {
+                    sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
+                }
+                dataset.load_state_dict(dataset_dict)

-        def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
-            if isinstance(dataloader.dataset, CaptureIterableDataset):
-                # provide the `state_dict` to the `CaptureIterableDataset`
-                # as it is responsible for passing down the state to associated `FastForwardSampler`
-                dataloader.dataset.load_state_dict(state_dict)
             else:
-                # for `Mapping-based` dataset, the `fast_forward_sampler` was attached
-                # on the dataloader for simplicity
-                dataloader.fast_forward_sampler.load_state_dict(state_dict)
+                raise MisconfigurationException(
+                    "This shouldn't happen. Please open an issue on PyTorch Lightning GitHub."
+                )
+
+            # We finally spawn the workers, if any.
+            it = iter(dataloader_to_iter_on)

-            # cycle back the iterator to the failed worker if multiple workers were provided
-            iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
+            # restore caching state
+            state = MergedIteratorState.from_state_dict(state_dict)

-            if isinstance(dataloader.dataset, CaptureIterableDataset):
-                # remove keys related to iterator
-                state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
-                # need to re-attach the state dict into the iterator for future collection.
-                iterator._sampler_state_dict = [state_dict]
-            return iterator
+            if isinstance(dataloader_to_iter_on, CycleIterator):
+                it._loader_iter.state = state
+            else:
+                it.state = state
+            return it
+
+        # create a type that nothing is an instance of, so the match below
+        # doesn't activate for anything other than an iterator.
+        class DataLoaderDict(dict):
+            pass

         # apply the `create_loader_iters` on the collection of `DataLoader / Iterator`.
         # each `Iterator` was created from the `DataLoader`.
         iterator._loader_iters = apply_to_collections(
-            self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
+            self.loaders,
+            self._loaders_iter_state_dict,
+            (Iterable, DataLoaderDict),
+            create_loader_iters,
+            wrong_dtype=(Sequence, Mapping),
         )

+        self._loaders_iter_state_dict = None
+
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
         """Return a collections of samplers extracting from loaders."""
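To show the intended call sequence of the reworked fault-tolerant API above, here is a hedged sketch. It only illustrates call order: in real use the Trainer first wraps the dataloaders with the capture datasets and fast-forward samplers (and fault-tolerant training must be enabled), otherwise `state_dict` returns empty dicts. The toy loaders are assumptions, not code from this commit.

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader

loaders = {"a": DataLoader(range(10), batch_size=4), "b": DataLoader(range(20), batch_size=5)}
combined = CombinedLoader(loaders, mode="max_size_cycle")

iterator = iter(combined)
batch = next(iterator)

# `has_completed=True` captures the state of the batch just consumed; the default
# `has_completed=False` returns the previously cached state instead, covering a failure
# that happens while the next batch is still being prefetched.
checkpointed_state = combined.state_dict(has_completed=True)

# ... the process restarts ...
restored = CombinedLoader(loaders, mode="max_size_cycle")
restored.load_state_dict(checkpointed_state)
# `on_restart` is expected to consume the stored state when the new iterator is created,
# re-attaching it to the wrapped loaders so iteration can fast-forward.
resumed_batch = next(iter(restored))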
@@ -457,7 +493,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
         self.loaders = apply_to_collection(
             self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
         )
-
         state.reset()

     def __iter__(self) -> Any:
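To make the wrapping pattern in the hunk above concrete, a small, hedged illustration of `apply_to_collection` semantics: every leaf matching the given dtype is replaced by `function(leaf, **kwargs)`, while containers listed in `wrong_dtype` are traversed instead of matched. The `tag` helper and the toy iterators are illustrative stand-ins for `CycleIterator` and real dataloaders.

from collections.abc import Iterable, Mapping, Sequence

from pytorch_lightning.utilities.apply_func import apply_to_collection


def tag(loader: Iterable, length: int) -> str:
    # stand-in for `CycleIterator(loader, length=length, state=state)`
    return f"CycleIterator(length={length})"


loaders = {"a": iter(range(3)), "b": iter(range(5))}
wrapped = apply_to_collection(loaders, Iterable, tag, length=5, wrong_dtype=(Sequence, Mapping))
print(wrapped)  # {'a': 'CycleIterator(length=5)', 'b': 'CycleIterator(length=5)'}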