Skip to content

Commit b743a69

Browse files
committed
Make typing_extensions a dev-dependency
1 parent 0e36e1e commit b743a69

File tree

4 files changed

+19
-7
lines changed

4 files changed

+19
-7
lines changed

jax/_src/api.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@
2828
from functools import partial
2929
import inspect
3030
import math
31-
from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast, overload
31+
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, cast,
32+
overload, TYPE_CHECKING)
3233
import weakref
3334

3435
import numpy as np
35-
from typing_extensions import ParamSpec
36+
if TYPE_CHECKING:
37+
from typing_extensions import ParamSpec
3638

3739
from jax._src import linear_util as lu
3840
from jax._src import stages
@@ -96,7 +98,10 @@
9698
T = TypeVar("T")
9799
U = TypeVar("U")
98100
V_co = TypeVar("V_co", covariant=True)
99-
P = ParamSpec("P")
101+
if TYPE_CHECKING:
102+
P = ParamSpec("P")
103+
else:
104+
P = TypeVar("P")
100105

101106

102107
map, unsafe_map = safe_map, map

jax/_src/pjit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,8 @@ def ax_leaf(l):
589589
V_co = TypeVar("V_co", covariant=True)
590590
if TYPE_CHECKING:
591591
P = ParamSpec("P")
592+
else:
593+
P = TypeVar("P")
592594

593595

594596
class JitWrapped(stages.Wrapped[P, V_co], Generic[P, V_co]):

jax/_src/stages.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@
3232

3333
from collections.abc import Sequence
3434
from dataclasses import dataclass
35-
from typing import Any, Generic, NamedTuple, Protocol, TypeVar, Union
35+
from typing import (Any, Generic, NamedTuple, Protocol, TypeVar, Union,
36+
TYPE_CHECKING)
3637
import warnings
3738

38-
from typing_extensions import ParamSpec
39+
if TYPE_CHECKING:
40+
from typing_extensions import ParamSpec
3941

4042
import jax
4143

@@ -710,7 +712,10 @@ def cost_analysis(self) -> Any | None:
710712

711713

712714
V_co = TypeVar("V_co", covariant=True)
713-
P = ParamSpec("P")
715+
if TYPE_CHECKING:
716+
P = ParamSpec("P")
717+
else:
718+
P = TypeVar("P")
714719

715720

716721
class Wrapped(Protocol, Generic[P, V_co]):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def generate_proto(source):
8484
# Python versions < 3.10. Can be dropped when 3.10 is the minimum
8585
# required Python version.
8686
'importlib_metadata>=4.6;python_version<"3.10"',
87-
'typing_extensions>=4.5.0',
8887
],
8988
extras_require={
89+
'dev': ['typing_extensions>=4.8.0'],
9090
# Minimum jaxlib version; used in testing.
9191
'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],
9292

0 commit comments

Comments
 (0)