
Generic Protocol for transformed function #10311

@YouJiacheng

With ParamSpec (PEP 612) we can fully annotate the signature of jit.
This lets type checkers check calls to the transformed function (and to its lower method).

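from typing import Callable, TypeVar
from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")
def jit(fun: Callable[P, T], ...) -> stages.Wrapped[P, T]:  # "..." stands for jit's other parameters
  ...

# stages.py
from typing import TypeVar
from typing_extensions import ParamSpec, Protocol

T_co = TypeVar("T_co", covariant=True)
P_contra = ParamSpec("P_contra", contravariant=True)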
class Wrapped(Protocol[P_contra, T_co]):
  def __call__(self, *args: P_contra.args, **kwargs: P_contra.kwargs) -> T_co:
    ...

  def lower(self, *args: P_contra.args, **kwargs: P_contra.kwargs) -> Lowered:
    ...

This is similar to #9999, and it works almost perfectly with Visual Studio Code + Pylance: the function's signature (including parameter names) is fully preserved.
However, mypy currently does not correctly handle functools.partial combined with a function transform: python/mypy#12593
We can either wait for mypy to fix this bug or adopt a workaround.
Some workaround options:

Make jit polymorphic: usable both as a decorator and as a decorator factory.

Toy example:

from typing import TypeVar, Callable, overload
from typing_extensions import ParamSpec, Protocol, Literal

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
P = ParamSpec('P')

class Wrapped(Protocol[P, T_co]):
  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_co:
    pass

@overload
def jit(f: Callable[P, T], *, a: bool) -> Wrapped[P, T]:
    ...

@overload
def jit(f: Literal[None] = None, *, a: bool) -> Callable[[Callable[P, T]], Wrapped[P, T]]:
    ...

def jit(f=None, *, a):
    if f is None:
        return lambda f: jit(f, a=a)
    # original code here
    return f

@jit(a=True)
def f(x: int, y: int):
    return x + y

f(0, 1) # mypy ok
f(0, 'a') # mypy error  Argument 2 to "__call__" of "Wrapped" has incompatible type "str"; expected "int"

I personally like this option, but it conflicts with JAX's stated philosophy: "Our general philosophy is to make the JAX core API as simple and explicit as possible" (@mattjj, #184 (comment)).

Add a dedicated decorator factory function and use it internally.

Toy example:

from typing import TypeVar, Callable
from typing_extensions import ParamSpec, Protocol

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
P = ParamSpec('P')

class Wrapped(Protocol[P, T_co]):
  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_co:
    pass

def jit(f: Callable[P, T], *, a: bool) -> Wrapped[P, T]:
    # original code here
    return f

def _make_jit(*, a: bool) -> Callable[[Callable[P, T]], Wrapped[P, T]]:
    return lambda f: jit(f, a=a) # type: ignore # a mypy bug fixed in mypy-0.950

@_make_jit(a=True)
def f(x: int, y: int):
    return x + y

f(0, 0) # mypy ok
f(0, 'a') # mypy error  Argument 2 to "__call__" of "Wrapped" has incompatible type "str"; expected "int"

I think this option should be acceptable to everyone (and the factory could also become a public API).
In addition, we can write a generic helper that turns any transform into a decorator factory without losing type information, but the current mypy release (and even the next one) has a bug in this case as well: python/mypy#12595

from typing import TypeVar, Callable
from typing_extensions import Concatenate, ParamSpec, Protocol

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
P = ParamSpec('P')

class Wrapped(Protocol[P, T_co]):
  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T_co:
    pass

def jit(f: Callable[P, T], *, a: bool) -> Wrapped[P, T]:
    # original code here
    return f

_T = TypeVar('_T')
_P = ParamSpec('_P')

def _make_factory(
    transform: Callable[Concatenate[Callable[P, T], _P], _T] # type: ignore
    # mypy can't recognize Concatenate here; fixed in mypy 0.950
) -> Callable[_P, Callable[[Callable[P, T]], _T]]:
    return lambda *args, **kwargs: lambda f: transform(f, *args, **kwargs)

_make_jit: Callable[..., Callable[[Callable[P, T]], Wrapped[P, T]]] = _make_factory(jit)
# mypy needs a type annotation here, but the annotation erases jit's signature (bad for VS Code)
# not fixed in mypy 0.950 either

@_make_jit(a=True)
def f(x: int, y: int):
    return x + y

f(0, 0) # mypy ok
f(0, 'a') # mypy error Argument 2 to "__call__" of "Wrapped" has incompatible type "str"; expected "int"

Labels: enhancement (New feature or request)
