1515import inspect
1616from contextlib import contextmanager
1717from itertools import chain
18- from typing import Any , Callable , Dict , Generator , Iterator , Optional , Set , Type , Union
18+ from typing import Any , Callable , Generator , Iterator , Optional , Set , Type , Union
1919
2020import torch
2121from torch import nn as nn
@@ -110,21 +110,25 @@ def _convert_float_tensor(t: Tensor) -> Tensor:
110110 return output
111111
112112
113- def _wrap_init (f : Callable ) -> Callable :
114- @functools .wraps (f )
115- def wrapper (module : Any , * args : Any , ** kwargs : Dict [str , Any ]) -> None :
116- params = dict (inspect .signature (module ._old_init ).parameters )
113+ def _wrap_init (init : Callable ) -> Callable :
114+ """Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
115+ :class:`~torch.utils.data.DataLoader`."""
116+
117+ @functools .wraps (init )
118+ def wrapper (obj : DataLoader , * args : Any , ** kwargs : Any ) -> None :
119+ params = dict (inspect .signature (obj .__init__ ).parameters )
117120 params .pop ("args" , None )
118121 params .pop ("kwargs" , None )
119- for init_name , init_arg in chain (zip (params , args ), kwargs .items ()):
120- setattr (module , init_name , init_arg )
121- f ( module , * args , ** kwargs )
122+ for arg_name , arg_value in chain (zip (params , args ), kwargs .items ()):
123+ setattr (obj , arg_name , arg_value )
124+ init ( obj , * args , ** kwargs )
122125
123126 return wrapper
124127
125128
126129# https://stackoverflow.com/a/63851681/9201239
127130def _get_all_subclasses (cls : Type [Any ]) -> Set [Type [Any ]]:
131+ """Returns a list of all classes that inherit directly or indirectly from the given class."""
128132 subclasses = set ()
129133
130134 def recurse (cl : Type [Any ]) -> None :
@@ -136,24 +140,17 @@ def recurse(cl: Type[Any]) -> None:
136140 return subclasses
137141
138142
139- def _enable_class (cls : Type [Any ]) -> None :
140- cls ._old_init = cls .__init__
141- cls .__init__ = _wrap_init (cls .__init__ )
142-
143-
144- def _disable_class (cls : Type [Any ]) -> None :
145- cls .__init__ = cls ._old_init
146- del cls ._old_init
147-
148-
149143@contextmanager
150- def _replace_dataloader_init_method () -> Generator :
151- """This context manager is used to support custom :class:`~torch.utils.data.DataLoader."""
144+ def _replace_dataloader_init_method () -> Generator [None , None , None ]:
145+ """This context manager is used to add support for re-instantiation of custom (subclasses) of
146+ :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
152147 for subclass in _get_all_subclasses (DataLoader ):
153- _enable_class (subclass )
148+ subclass ._old_init = subclass .__init__
149+ subclass .__init__ = _wrap_init (subclass .__init__ )
154150 yield
155151 for subclass in _get_all_subclasses (DataLoader ):
156- _disable_class (subclass )
152+ subclass .__init__ = subclass ._old_init
153+ del subclass ._old_init
157154
158155
159156class _LiteDataLoader :
0 commit comments