-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Description
With `ParamSpec` we can completely annotate the signature of `jit`.
Type checking can then be performed on the transformed function (and on its `lower` method).
T = TypeVar("T")
P = ParamSpec("P")
def jit(fun: Callable[P, T], ...) -> stages.Wrapped[P, T]:
...
# stages.py
T_co = TypeVar("T_co", covariant=True)
P_contra = ParamSpec("P_contra", contravariant=True)
class Wrapped(Protocol[P_contra, T_co]):
def __call__(self, *args: P_contra.args, **kwargs: P_contra.kwargs) -> T_co:
...
def lower(self, *args: P_contra.args, **kwargs: P_contra.kwargs) -> Lowered:
...Just like #9999 , and it works almost perfectly with Visual Studio Code + pylance, signature(including parameters name) of function is fully preserved.
However, mypy currently doesn't handle `functools.partial` applied to a function transform (e.g. `partial(jit, ...)`) correctly; see python/mypy#12593.
We can wait for mypy to fix this bug, or implement a workaround.
Some workaround options:
Option 1: make `jit` polymorphic, i.e. usable both as a plain decorator and as a decorator factory.
toy example
from typing import TypeVar, Callable, overload
from typing_extensions import ParamSpec, Protocol, Literal
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
P = ParamSpec('P')
class Wrapped(Protocol[P, T_co]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_co:
pass
@overload
def jit(f: Callable[P, T], *, a: bool) -> Wrapped[P, T]:
...
@overload
def jit(f: Literal[None] = None, *, a: bool) -> Callable[[Callable[P, T]], Wrapped[P, T]]:
...
def jit(f=None, *, a):
if f is None:
return lambda f: jit(f, a=a)
# original code here
return f
@jit(a=True)
def f(x: int, y: int):
return x + y
f(0, 1) #mypy ok
f(0, 'a') # mypy error Argument 2 to "__call__" of "Wrapped" has incompatible type "str"; expected "int"This break "Our general philosophy is to make the JAX core API as simple and explicit as possible" @mattjj #184 (comment)
I personally like this option, but it breaks the philosophy of JAX.
Option 2: add a dedicated decorator-factory function and use it internally.
toy example
from typing import TypeVar, Callable
from typing_extensions import ParamSpec, Protocol
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
P = ParamSpec('P')
class Wrapped(Protocol[P, T_co]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_co:
pass
def jit(f: Callable[P, T], *, a: bool) -> Wrapped[P, T]:
# original code here
return f
def _make_jit(*, a: bool) -> Callable[[Callable[P, T]], Wrapped[P, T]]:
return lambda f: jit(f, a=a) # type: ignore # a mypy bug fixed in mypy-0.950
@_make_jit(a=True)
def f(x: int, y: int):
return x + y
f(0, 0) # mypy ok
f(0, 'a') # mypy error Argument 2 to "__call__" of "Wrapped" has incompatible type "str"; expected "int"I think this option is satisfactory for everyone. (also can be a public API)
In addition, we could write a function that creates a decorator-factory function without losing type information, but the current mypy release (and even the next one) has a bug in this case as well; see python/mypy#12595.
from typing import TypeVar, Callable
from typing_extensions import Concatenate, ParamSpec, Protocol
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
P = ParamSpec('P')
class Wrapped(Protocol[P, T_co]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_co:
pass
def jit(f: Callable[P, T], *, a: bool) -> Wrapped[P, T]:
# original code here
return f
_T = TypeVar('_T')
_P = ParamSpec('_P')
def _make_factory(
transform: Callable[Concatenate[Callable[P, T], _P], _T] # type: ignore
# mypy can't recognize Concatenate, mypy-0.950 has fixed it
) -> Callable[_P, Callable[[Callable[P, T]], _T]]:
return lambda *args, **kwargs: lambda f: transform(f, *args, **kwargs)
_make_jit: Callable[..., Callable[[Callable[P, T]], Wrapped[P, T]]] = _make_factory(jit)
# mypy need a type annotation, but this annotation will erase signature of jit -> bad for vscode
# mypy-0.950 has not fixed it
@_make_jit(a=True)
def f(x: int, y: int):
return x + y
f(0, 0) # mypy ok
f(0, 'a') # mypy error Argument 2 to "__call__" of "Wrapped" has incompatible type "str"; expected "int"