@@ -14,6 +14,7 @@
 import functools
 import inspect
 import os
+from collections import OrderedDict
 from contextlib import contextmanager
 from dataclasses import fields
 from functools import partial
@@ -220,11 +221,11 @@ def _get_dataloader_init_args_and_kwargs(
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

-    was_wrapped = hasattr(dataloader, "__pl_dl_args")
+    was_wrapped = hasattr(dataloader, "__pl_saved_args")
     if was_wrapped:
-        dl_args = dataloader.__pl_dl_args
-        dl_kwargs = dataloader.__pl_dl_kwargs
-        arg_names = dataloader.__pl_dl_arg_names
+        dl_args = dataloader.__pl_saved_args
+        dl_kwargs = dataloader.__pl_saved_kwargs
+        arg_names = dataloader.__pl_saved_arg_names
         original_dataset = dataloader.__dataset  # we have this saved from _wrap_init
     else:
         # get the dataloader instance attributes
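For orientation, here is a minimal sketch of what the renamed `__pl_saved_*` attributes hold once the patched `__init__` (see `_wrap_init_method` below) has run. `MyDataLoader`, `ds`, and `custom_arg` are invented for illustration:

```python
from torch.utils.data import DataLoader, Dataset


class MyDataLoader(DataLoader):  # hypothetical user subclass
    def __init__(self, dataset: Dataset, custom_arg: int = 0, **kwargs):
        self.custom_arg = custom_arg
        super().__init__(dataset, **kwargs)


# Constructing `MyDataLoader(ds, 5, batch_size=8)` under the patch records:
#   loader.__pl_saved_args      == (ds, 5)
#   loader.__pl_saved_arg_names == ("dataset", "custom_arg")
#   loader.__pl_saved_kwargs    == {"batch_size": 8}
#   loader.__dataset            == ds  # stored because store_explicit_arg="dataset"
```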
@@ -323,6 +324,9 @@ def _dataloader_init_kwargs_resolve_sampler(
     If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
     Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
     `FastForwardSampler`.
+
+    If there are multiple devices in IPU mode, a `BatchSampler` that was not instantiated automatically must be
+    disallowed, since `poptorch.DataLoader` will try to increase the batch size.
     """
     fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
     batch_sampler = getattr(dataloader, "batch_sampler")
@@ -341,11 +345,59 @@ def _dataloader_init_kwargs_resolve_sampler(
341345 "when running on multiple IPU devices."
342346 )
343347 elif type (batch_sampler ) is not BatchSampler or is_predicting :
344- batch_sampler = type (batch_sampler )(
345- sampler ,
346- batch_size = batch_sampler .batch_size ,
347- drop_last = (False if is_predicting else batch_sampler .drop_last ),
348- )
348+ batch_sampler_cls = type (batch_sampler )
349+ if hasattr (batch_sampler , "__pl_saved_args" ):
350+ args = batch_sampler .__pl_saved_args
351+ kwargs = batch_sampler .__pl_saved_kwargs
352+ default_kwargs = batch_sampler .__pl_saved_default_kwargs
353+ arg_names = batch_sampler .__pl_saved_arg_names
354+
355+ if is_predicting :
356+ success , args , kwargs = _replace_value_in_saved_args (
357+ "drop_last" , False , args , kwargs , default_kwargs , arg_names
358+ )
359+ if not success :
360+ rank_zero_warn (
361+ f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however "
362+ f"it seems the class `{ batch_sampler_cls .__qualname__ } ` does not support it. "
363+ "Your predictions might be incomplete. To mitigate this, expose `drop_last` in "
364+ "the `__init__` method of your custom class."
365+ )
366+
367+ success , args , kwargs = _replace_value_in_saved_args (
368+ "sampler" , sampler , args , kwargs , default_kwargs , arg_names
369+ )
370+ if not success :
371+ raise TypeError (
372+ "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
373+ f"`{ batch_sampler_cls .__qualname__ } ` does not have an argument called `sampler.` To mitigate "
374+ "this, expose an argument `sampler` in the `__init__` method of your custom class."
375+ )
376+
377+ batch_sampler = batch_sampler_cls (* args , ** kwargs )
378+ else :
379+ try :
380+ batch_sampler = batch_sampler_cls (
381+ sampler ,
382+ batch_size = batch_sampler .batch_size ,
383+ drop_last = (False if is_predicting else batch_sampler .drop_last ),
384+ )
385+ except TypeError as e :
386+ import re
387+
388+ match = re .match (r".*__init__\(\) (got multiple values)|(missing \d required)" , str (e ))
389+ if not match :
390+ # an unexpected `TypeError`, continue failure
391+ raise
392+
393+ # There could either be too few or too many arguments. Customizing the message based on this doesn't
394+ # make much sense since our MisconfigurationException is going to be raised from the original one.
395+ raise MisconfigurationException (
396+ "We tried to re-instantiate your custom batch sampler and failed. "
397+ "To mitigate this, either follow the API of `BatchSampler` or instantiate "
398+ "your custom batch sampler inside `*_dataloader` hooks of your module."
399+ ) from e
400+
349401 if is_predicting :
350402 batch_sampler = IndexBatchSamplerWrapper (batch_sampler )
351403
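The warning and error messages above spell out the contract for custom batch samplers: expose `sampler` and `drop_last` in `__init__` so Lightning can inject replacements. A minimal sketch of a conforming subclass (the class itself is invented for illustration):

```python
from torch.utils.data import BatchSampler, Sampler


class SkipFirstBatchSampler(BatchSampler):  # hypothetical user subclass
    """Exposes `sampler` and `drop_last`, so both injections above can succeed."""

    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool) -> None:
        super().__init__(sampler, batch_size, drop_last)

    def __iter__(self):
        batches = super().__iter__()
        next(batches, None)  # illustrative tweak: skip the first batch
        yield from batches
```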
@@ -368,39 +420,73 @@ def _dataloader_init_kwargs_resolve_sampler(
     return {"sampler": sampler, "shuffle": False, "batch_sampler": None}


+def _replace_value_in_saved_args(
+    replace_key: str,
+    replace_value: Any,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    default_kwargs: Dict[str, Any],
+    arg_names: Tuple[str, ...],
+) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
+    """Tries to replace an argument value in a saved list of args and kwargs.
+
+    Returns a tuple indicating success of the operation and the modified saved args and kwargs.
+    """
+
+    if replace_key in arg_names:
+        replace_index = arg_names.index(replace_key)
+        args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :]
+        return True, args, kwargs
+    elif replace_key in kwargs or replace_key in default_kwargs:
+        kwargs[replace_key] = replace_value
+        return True, args, kwargs
+
+    return False, args, kwargs
+
+
 def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
     if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
         dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


-def _wrap_dataloader_init(init: Callable) -> Callable:
-    """Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
-    of custom subclasses."""
+def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
+    """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
+    :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

     @functools.wraps(init)
-    def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
+    def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
         # We need to inspect `init`, as inspecting `obj.__init__`
         # can lead to inspecting the wrong function with multiple inheritance
         params = inspect.signature(init).parameters
-        param_names = tuple(
-            param.name
+
+        parameters_defaults = OrderedDict(
+            (param.name, param.default)
             for param in params.values()
             if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
         )
-        param_names = param_names[: len(args)]

-        if not hasattr(obj, "__pl_dl_args"):
-            obj.__pl_dl_args = args
-            obj.__pl_dl_kwargs = kwargs
-            obj.__pl_dl_arg_names = param_names
+        param_names = tuple(parameters_defaults)[: len(args)]

-        # We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
+        default_kwargs = {
+            name: value
+            for name, value in parameters_defaults.items()
+            if name not in kwargs and name not in param_names and value != inspect.Parameter.empty
+        }
+
+        if not hasattr(obj, "__pl_saved_args"):
+            obj.__pl_saved_args = args
+            obj.__pl_saved_kwargs = kwargs
+            obj.__pl_saved_arg_names = param_names
+            obj.__pl_saved_default_kwargs = default_kwargs
+
+        # We want to use the latest possible value for the explicit argument (i.e. ideally what gets passed to base class)
         # so that we can be sure that it will not get changed anymore.
         # That is why we are setting this in every `__init__`
-        if "dataset" in param_names:
-            setattr(obj, "__dataset", args[param_names.index("dataset")])
-        elif "dataset" in kwargs:
-            setattr(obj, "__dataset", kwargs["dataset"])
+        if store_explicit_arg is not None:
+            if store_explicit_arg in param_names:
+                setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
+            elif store_explicit_arg in kwargs:
+                setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])

         init(obj, *args, **kwargs)

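To make the replacement semantics concrete, a small worked example of `_replace_value_in_saved_args`; the saved call and sampler stand-ins are hypothetical:

```python
# Saved call being rewritten: CustomBatchSampler(old_sampler, 32, drop_last=True)
old_sampler, new_sampler = object(), object()  # stand-ins for real samplers
args = (old_sampler, 32)
kwargs = {"drop_last": True}
default_kwargs = {}  # no defaults were left untouched in this call
arg_names = ("sampler", "batch_size")

# `sampler` was passed positionally, so the args tuple is rebuilt:
ok, args, kwargs = _replace_value_in_saved_args(
    "sampler", new_sampler, args, kwargs, default_kwargs, arg_names
)
assert ok and args == (new_sampler, 32)

# `drop_last` was passed as a keyword, so only kwargs change:
ok, args, kwargs = _replace_value_in_saved_args(
    "drop_last", False, args, kwargs, default_kwargs, arg_names
)
assert ok and kwargs == {"drop_last": False}

# An unknown name leaves everything untouched and reports failure:
ok, args, kwargs = _replace_value_in_saved_args("seed", 0, args, kwargs, default_kwargs, arg_names)
assert not ok
```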
@@ -422,15 +508,17 @@ def recurse(cl: Type[Any]) -> None:


 @contextmanager
-def _replace_dataloader_init_method() -> Generator[None, None, None]:
-    """This context manager is used to add support for re-instantiation of custom (subclasses) of
-    :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
-    classes = _get_all_subclasses(DataLoader) | {DataLoader}
+def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
+    """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
+
+    It patches the ``__init__`` method.
+    """
+    classes = _get_all_subclasses(base_cls) | {base_cls}
     wrapped = set()
     for cls in classes:
         if cls.__init__ not in wrapped:
             cls._old_init = cls.__init__
-            cls.__init__ = _wrap_dataloader_init(cls.__init__)
+            cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
             wrapped.add(cls.__init__)
     yield
     for cls in classes:
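A sketch of how the generalized context manager can be driven; the actual call sites in the trainer live outside this diff, so the usage below is illustrative only:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# DataLoader patching keeps the special handling of the `dataset` argument:
with _replace_init_method(DataLoader, store_explicit_arg="dataset"):
    loader = DataLoader(dataset, batch_size=4)
assert loader.__pl_saved_arg_names == ("dataset",)

# BatchSampler patching only needs the generic argument capture:
with _replace_init_method(BatchSampler):
    bs = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
assert bs.__pl_saved_kwargs == {"batch_size": 4, "drop_last": False}
```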
@@ -475,13 +563,13 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(


 def _is_dataloader_shuffled(dataloader: object) -> bool:
-    if hasattr(dataloader, "__pl_dl_kwargs"):
+    if hasattr(dataloader, "__pl_saved_kwargs"):
         # this attribute is not part of PyTorch's DataLoader, but could have been set by
-        # our `_replace_dataloader_init_method` context manager
-        if "shuffle" in dataloader.__pl_dl_kwargs:
-            return dataloader.__pl_dl_kwargs["shuffle"]
-        if "shuffle" in dataloader.__pl_dl_arg_names:
-            return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
+        # our `_replace_init_method` context manager
+        if "shuffle" in dataloader.__pl_saved_kwargs:
+            return dataloader.__pl_saved_kwargs["shuffle"]
+        if "shuffle" in dataloader.__pl_saved_arg_names:
+            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
     if isinstance(dataloader.dataset, IterableDataset):
         # shuffling is useless with iterable datasets
         return False
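Finally, the shuffle check in action, reusing the patched context and `dataset` from the previous sketch and assuming the rest of `_is_dataloader_shuffled` falls back to sampler inspection as before:

```python
with _replace_init_method(DataLoader, store_explicit_arg="dataset"):
    shuffled = DataLoader(dataset, shuffle=True)
    plain = DataLoader(dataset)

# `shuffle=True` was recorded in `__pl_saved_kwargs`, so the saved call alone
# answers the question; `plain` falls through to the sampler-based checks.
assert _is_dataloader_shuffled(shuffled)
assert not _is_dataloader_shuffled(plain)
```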