Skip to content

Commit 05c55f6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1874e1a commit 05c55f6

File tree

1 file changed

+12
-8
lines changed
  • src/pytorch_lightning/utilities

1 file changed

+12
-8
lines changed

src/pytorch_lightning/utilities/meta.py

Lines changed: 12 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -64,7 +64,7 @@ def _no_dispatch() -> Iterator[None]:
6464
yield
6565
finally:
6666
del guard
67-
67+
6868
def _handle_arange(func: Callable, args: Any, kwargs: Any) -> Tensor:
6969
kwargs["device"] = torch.device("cpu")
7070
return torch.empty_like(func(*args, **kwargs), device="meta")
@@ -74,7 +74,7 @@ def _handle_tril(func: Callable, args: Any, kwargs: Any) -> Union[Tensor, Any]:
7474
return torch.empty_like(args[0], device="meta")
7575

7676
return NotImplemented
77-
77+
7878
class _MetaContext(Tensor):
7979
_op_handlers: Dict[Callable, Callable] = {}
8080

@@ -91,7 +91,7 @@ def _ensure_handlers_initialized(cls) -> None:
9191
)
9292

9393
@classmethod
94-
def __torch_dispatch__(cls, func: Callable, types: Any, args: Any=(), kwargs: Optional[Any]=None) -> Any:
94+
def __torch_dispatch__(cls, func: Callable, types: Any, args: Any = (), kwargs: Optional[Any] = None) -> Any:
9595
cls._ensure_handlers_initialized()
9696

9797
op_handler: Optional[Callable]
@@ -112,8 +112,10 @@ def __torch_dispatch__(cls, func: Callable, types: Any, args: Any=(), kwargs: Op
112112

113113
return func(*args, **(kwargs if kwargs is not None else {}))
114114

115-
def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> Union[Module, MisconfigurationException]:
116-
def create_instance(module: Optional[Any]=None) -> Module:
115+
def init_meta(
116+
module_fn: Callable[..., Module], *args: Any, **kwargs: Any
117+
) -> Union[Module, MisconfigurationException]:
118+
def create_instance(module: Optional[Any] = None) -> Module:
117119
if module:
118120
module.__init__(*args, **kwargs)
119121
return module
@@ -144,7 +146,9 @@ def is_meta_init() -> bool:
144146

145147
else:
146148

147-
def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> Union[Module, MisconfigurationException]:
149+
def init_meta(
150+
module_fn: Callable[..., Module], *args: Any, **kwargs: Any
151+
) -> Union[Module, MisconfigurationException]:
148152
if not _TORCH_GREATER_EQUAL_1_10:
149153
return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
150154

@@ -194,7 +198,7 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
194198

195199

196200
# cache subclasses to optimize the search when resetting the meta device later on.
197-
__STORAGE_META__: Dict[Type, Tuple]= {}
201+
__STORAGE_META__: Dict[Type, Tuple] = {}
198202
__CREATED_MODULES__: Set[Type] = set()
199203

200204

@@ -287,7 +291,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Type:
287291
cls.add_subclasses(subclass)
288292
with cls.instantiation_context():
289293
obj = init_meta(subclass, *args, **kwargs)
290-
if(isinstance(obj, Exception)):
294+
if isinstance(obj, Exception):
291295
raise obj
292296

293297
obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize) # type: ignore[assignment]

0 commit comments

Comments (0)