@@ -64,7 +64,7 @@ def _no_dispatch() -> Iterator[None]:
6464 yield
6565 finally :
6666 del guard
67-
67+
6868 def _handle_arange (func : Callable , args : Any , kwargs : Any ) -> Tensor :
6969 kwargs ["device" ] = torch .device ("cpu" )
7070 return torch .empty_like (func (* args , ** kwargs ), device = "meta" )
@@ -74,7 +74,7 @@ def _handle_tril(func: Callable, args: Any, kwargs: Any) -> Union[Tensor, Any]:
7474 return torch .empty_like (args [0 ], device = "meta" )
7575
7676 return NotImplemented
77-
77+
7878 class _MetaContext (Tensor ):
7979 _op_handlers : Dict [Callable , Callable ] = {}
8080
@@ -91,7 +91,7 @@ def _ensure_handlers_initialized(cls) -> None:
9191 )
9292
9393 @classmethod
94- def __torch_dispatch__ (cls , func : Callable , types : Any , args : Any = (), kwargs : Optional [Any ]= None ) -> Any :
94+ def __torch_dispatch__ (cls , func : Callable , types : Any , args : Any = (), kwargs : Optional [Any ] = None ) -> Any :
9595 cls ._ensure_handlers_initialized ()
9696
9797 op_handler : Optional [Callable ]
@@ -112,8 +112,10 @@ def __torch_dispatch__(cls, func: Callable, types: Any, args: Any=(), kwargs: Op
112112
113113 return func (* args , ** (kwargs if kwargs is not None else {}))
114114
115- def init_meta (module_fn : Callable [..., Module ], * args : Any , ** kwargs : Any ) -> Union [Module , MisconfigurationException ]:
116- def create_instance (module : Optional [Any ]= None ) -> Module :
115+ def init_meta (
116+ module_fn : Callable [..., Module ], * args : Any , ** kwargs : Any
117+ ) -> Union [Module , MisconfigurationException ]:
118+ def create_instance (module : Optional [Any ] = None ) -> Module :
117119 if module :
118120 module .__init__ (* args , ** kwargs )
119121 return module
@@ -144,7 +146,9 @@ def is_meta_init() -> bool:
144146
145147else :
146148
147- def init_meta (module_fn : Callable [..., Module ], * args : Any , ** kwargs : Any ) -> Union [Module , MisconfigurationException ]:
149+ def init_meta (
150+ module_fn : Callable [..., Module ], * args : Any , ** kwargs : Any
151+ ) -> Union [Module , MisconfigurationException ]:
148152 if not _TORCH_GREATER_EQUAL_1_10 :
149153 return MisconfigurationException ("`init_meta` is supported from PyTorch 1.10.0" )
150154
@@ -194,7 +198,7 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
194198
195199
196200# cache subclasses to optimize the search when resetting the meta device later on.
197- __STORAGE_META__ : Dict [Type , Tuple ]= {}
201+ __STORAGE_META__ : Dict [Type , Tuple ] = {}
198202__CREATED_MODULES__ : Set [Type ] = set ()
199203
200204
@@ -287,7 +291,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Type:
287291 cls .add_subclasses (subclass )
288292 with cls .instantiation_context ():
289293 obj = init_meta (subclass , * args , ** kwargs )
290- if ( isinstance (obj , Exception ) ):
294+ if isinstance (obj , Exception ):
291295 raise obj
292296
293297 obj .materialize = partial (cls .materialize , materialize_fn = obj .materialize ) # type: ignore[assignment]
0 commit comments