Skip to content

Commit ba45b45

Browse files
committed
Use ParamSpec in jit annotation; bump MyPy to 1.0
1 parent 008f35a commit ba45b45

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ repos:
1414
- id: flake8
1515

1616
- repo: https://github.com/pre-commit/mirrors-mypy
17-
rev: 'v0.982'
17+
rev: 'v1.0.1'
1818
hooks:
1919
- id: mypy
2020
files: (jax/|tests/typing_test\.py)

jax/_src/api.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@
2222
from __future__ import annotations
2323

2424
import collections
25+
from contextlib import contextmanager, ExitStack
2526
import functools
2627
from functools import partial
2728
import inspect
2829
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
29-
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload)
30+
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union,
31+
overload)
3032

3133
import numpy as np
32-
from contextlib import contextmanager, ExitStack
34+
from typing_extensions import ParamSpec
3335

3436
import jax
3537
from jax._src import linear_util as lu
@@ -105,6 +107,9 @@
105107
F = TypeVar("F", bound=Callable)
106108
T = TypeVar("T")
107109
U = TypeVar("U")
110+
V_co = TypeVar("V_co", covariant=True)
111+
P = ParamSpec("P")
112+
108113

109114
map, unsafe_map = safe_map, map
110115
zip, unsafe_zip = safe_zip, zip
@@ -155,7 +160,7 @@ def _update_debug_special_thread_local(_):
155160

156161

157162
def jit(
158-
fun: Callable,
163+
fun: Callable[P, V_co],
159164
*,
160165
static_argnums: Union[int, Iterable[int], None] = None,
161166
static_argnames: Union[str, Iterable[str], None] = None,
@@ -165,7 +170,7 @@ def jit(
165170
inline: bool = False,
166171
keep_unused: bool = False,
167172
abstracted_axes: Optional[Any] = None,
168-
) -> stages.Wrapped:
173+
) -> stages.Wrapped[P, V_co]:
169174
"""Sets up ``fun`` for just-in-time compilation with XLA.
170175
171176
Args:
@@ -339,7 +344,7 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
339344
PytreeOfAbstractedAxesSpec = Any
340345

341346
def _python_jit(
342-
fun: Callable,
347+
fun: Callable[P, V_co],
343348
*,
344349
static_argnums: Tuple[int, ...],
345350
static_argnames: Tuple[str, ...],
@@ -349,7 +354,7 @@ def _python_jit(
349354
inline: bool,
350355
keep_unused: bool,
351356
abstracted_axes: Optional[PytreeOfAbstractedAxesSpec],
352-
) -> stages.Wrapped:
357+
) -> stages.Wrapped[P, V_co]:
353358
@wraps(fun)
354359
@api_boundary
355360
def f_jitted(*args, **kwargs):
@@ -483,7 +488,7 @@ def _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
483488
return None
484489

485490
def _cpp_jit(
486-
fun: Callable,
491+
fun: Callable[P, V_co],
487492
*,
488493
static_argnums: Tuple[int, ...],
489494
static_argnames: Tuple[str, ...],
@@ -492,7 +497,7 @@ def _cpp_jit(
492497
donate_argnums: Tuple[int, ...],
493498
inline: bool,
494499
keep_unused: bool,
495-
) -> stages.Wrapped:
500+
) -> stages.Wrapped[P, V_co]:
496501
# An implementation of `jit` that tries to do as much as possible in C++.
497502
# The goal of this function is to speed up the time it takes to process the
498503
# arguments, find the correct C++ executable, start the transfer of arguments
@@ -2064,7 +2069,7 @@ def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,
20642069

20652070

20662071
def _python_pmap(
2067-
fun: Callable,
2072+
fun: Callable[P, V_co],
20682073
axis_name: Optional[AxisName] = None,
20692074
*,
20702075
in_axes=0,
@@ -2075,7 +2080,7 @@ def _python_pmap(
20752080
axis_size: Optional[int] = None,
20762081
donate_argnums: Union[int, Iterable[int]] = (),
20772082
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
2078-
) -> stages.Wrapped:
2083+
) -> stages.Wrapped[P, V_co]:
20792084
"""The Python only implementation."""
20802085
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
20812086
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,

jax/_src/stages.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
import warnings
3434

3535
from dataclasses import dataclass
36-
from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple
36+
from typing import (Any, Dict, Generic, List, NamedTuple, Optional, Protocol,
37+
Sequence, Tuple, TypeVar)
38+
from typing_extensions import ParamSpec
3739

3840
import jax
3941
from jax import tree_util
@@ -617,7 +619,11 @@ def compiler_ir(self, dialect: Optional[str] = None) -> Optional[Any]:
617619
return None
618620

619621

620-
class Wrapped(Protocol):
622+
V_co = TypeVar("V_co", covariant=True)
623+
P = ParamSpec("P")
624+
625+
626+
class Wrapped(Protocol, Generic[P, V_co]):
621627
"""A function ready to be specialized, lowered, and compiled.
622628
623629
This protocol reflects the output of functions such as
@@ -626,7 +632,7 @@ class Wrapped(Protocol):
626632
to compilation, and the result compiled prior to execution.
627633
"""
628634

629-
def __call__(self, *args, **kwargs):
635+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> V_co:
630636
"""Executes the wrapped function, lowering and compiling as needed."""
631637
raise NotImplementedError
632638

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def generate_proto(source):
6767
'numpy>=1.20',
6868
'opt_einsum',
6969
'scipy>=1.5',
70+
'typing_extensions>=4.5.0',
7071
],
7172
extras_require={
7273
# Minimum jaxlib version; used in testing.

0 commit comments

Comments
 (0)