Labels: feature, topic-paramspec, topic-typed-dict
Feature
Related to
`TypedDict` unpacking in a `ParamSpec` would work just as it already does in `Callable`s:
```python
from typing import Callable, Generic, ParamSpec, TypedDict, Unpack

P = ParamSpec("P")

class C(Generic[P]):
    def __init__(self, f: Callable[P, None]): ...

class Args(TypedDict):
    x: int
    y: str

def f(*, x: int, y: str) -> None: ...

c: C[[Unpack[Args]]] = C(f)  # OK
d: C[[int, str]] = C(f)  # error because `f` expects keyword arguments
```

Pitch
To express a callable type with keyword arguments you can use a callback protocol, but that doesn't help with other classes that are generic in a `ParamSpec`. In PyTorch, for example, network layers have to inherit from `Module`, which would be typed roughly like this (using the Python 3.12 generic syntax):
```python
from abc import abstractmethod

class Module[T, **P]:
    @abstractmethod
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # do other stuff
        return self.forward(*args, **kwargs)
```

But what do I do if I want to override `forward` with an optional argument?
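```python
class Dense(Module[Tensor, [Tensor, bool]]):
    def forward(self, x: Tensor, *, with_dropout: bool = False):
        # my implementation
        return x
```

The parameter list `[Tensor, bool]` doesn't match this signature, because `with_dropout` is keyword-only and has a default. With `TypedDict` unpacking in a `ParamSpec`, it could be written as: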
```python
class ExtraArgs(TypedDict):
    with_dropout: NotRequired[bool]  # optional, matching the default in `forward`

class Dense(Module[Tensor, [Tensor, Unpack[ExtraArgs]]]):
    def forward(self, x: Tensor, *, with_dropout: bool = False):
        # my implementation
        return x
```