2222from __future__ import annotations
2323
2424import collections
25+ from contextlib import contextmanager , ExitStack
2526import functools
2627from functools import partial
2728import inspect
2829from typing import (Any , Callable , Generator , Hashable , Iterable , List , Literal ,
29- NamedTuple , Optional , Sequence , Tuple , TypeVar , Union , overload )
30+ NamedTuple , Optional , Sequence , Tuple , TypeVar , Union ,
31+ overload )
3032
3133import numpy as np
32- from contextlib import contextmanager , ExitStack
34+ from typing_extensions import ParamSpec
3335
3436import jax
3537from jax ._src import linear_util as lu
105107F = TypeVar ("F" , bound = Callable )
106108T = TypeVar ("T" )
107109U = TypeVar ("U" )
110+ V_co = TypeVar ("V_co" , covariant = True )
111+ P = ParamSpec ("P" )
112+
108113
109114map , unsafe_map = safe_map , map
110115zip , unsafe_zip = safe_zip , zip
@@ -155,7 +160,7 @@ def _update_debug_special_thread_local(_):
155160
156161
157162def jit (
158- fun : Callable ,
163+ fun : Callable [ P , V_co ] ,
159164 * ,
160165 static_argnums : Union [int , Iterable [int ], None ] = None ,
161166 static_argnames : Union [str , Iterable [str ], None ] = None ,
@@ -165,7 +170,7 @@ def jit(
165170 inline : bool = False ,
166171 keep_unused : bool = False ,
167172 abstracted_axes : Optional [Any ] = None ,
168- ) -> stages .Wrapped :
173+ ) -> stages .Wrapped [ P , V_co ] :
169174 """Sets up ``fun`` for just-in-time compilation with XLA.
170175
171176 Args:
@@ -339,7 +344,7 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
339344PytreeOfAbstractedAxesSpec = Any
340345
341346def _python_jit (
342- fun : Callable ,
347+ fun : Callable [ P , V_co ] ,
343348 * ,
344349 static_argnums : Tuple [int , ...],
345350 static_argnames : Tuple [str , ...],
@@ -349,7 +354,7 @@ def _python_jit(
349354 inline : bool ,
350355 keep_unused : bool ,
351356 abstracted_axes : Optional [PytreeOfAbstractedAxesSpec ],
352- ) -> stages .Wrapped :
357+ ) -> stages .Wrapped [ P , V_co ] :
353358 @wraps (fun )
354359 @api_boundary
355360 def f_jitted (* args , ** kwargs ):
@@ -483,7 +488,7 @@ def _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
483488 return None
484489
485490def _cpp_jit (
486- fun : Callable ,
491+ fun : Callable [ P , V_co ] ,
487492 * ,
488493 static_argnums : Tuple [int , ...],
489494 static_argnames : Tuple [str , ...],
@@ -492,7 +497,7 @@ def _cpp_jit(
492497 donate_argnums : Tuple [int , ...],
493498 inline : bool ,
494499 keep_unused : bool ,
495- ) -> stages .Wrapped :
500+ ) -> stages .Wrapped [ P , V_co ] :
496501 # An implementation of `jit` that tries to do as much as possible in C++.
497502 # The goal of this function is to speed up the time it takes to process the
498503 # arguments, find the correct C++ executable, start the transfer of arguments
@@ -2064,7 +2069,7 @@ def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,
20642069
20652070
20662071def _python_pmap (
2067- fun : Callable ,
2072+ fun : Callable [ P , V_co ] ,
20682073 axis_name : Optional [AxisName ] = None ,
20692074 * ,
20702075 in_axes = 0 ,
@@ -2075,7 +2080,7 @@ def _python_pmap(
20752080 axis_size : Optional [int ] = None ,
20762081 donate_argnums : Union [int , Iterable [int ]] = (),
20772082 global_arg_shapes : Optional [Tuple [Tuple [int , ...], ...]] = None ,
2078- ) -> stages .Wrapped :
2083+ ) -> stages .Wrapped [ P , V_co ] :
20792084 """The Python only implementation."""
20802085 axis_name , static_broadcasted_tuple , donate_tuple = _shared_code_pmap (
20812086 fun , axis_name , static_broadcasted_argnums , donate_argnums , in_axes ,
0 commit comments