Description
I've noticed that functions in _src/numpy/lax_numpy.py that are decorated with @util._wraps(...) lose their type annotations.
For example, jnp.clip is decorated with @util._wraps(np.clip, skip_params=['out']). Consider the following code:
import jax.numpy as jnp
from typing import TYPE_CHECKING
x = jnp.arange(17)
x_clipped = jnp.clip(x, a_max=14)
if TYPE_CHECKING:
    reveal_type(x)  # Array
    reveal_type(x_clipped)  # Unknown or Any
mypy reports that x_clipped has type Any, and Pyright reports Unknown. The expected type is jax._src.basearray.Array. In lax_numpy, the function clip itself is properly annotated; the annotations are lost only once the decorator is applied.
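To illustrate one plausible mechanism, here is a minimal sketch. It is not JAX's actual _wraps implementation: untyped_wraps, typed_wraps, reference, clip_a and clip_b are hypothetical names made up for this example. The assumption is that the decorator factory is unannotated, in which case mypy treats whatever it returns as Any. Annotating the factory with a TypeVar bound to Callable lets the checker carry the decorated function's signature through.

```python
from typing import Any, Callable, TypeVar

# TypeVar bound to "any callable", so the decorator returns exactly what it received.
F = TypeVar("F", bound=Callable[..., Any])


def untyped_wraps(doc_source):
    """Hypothetical unannotated decorator factory (assumed shape of the problem)."""
    def decorator(fn):
        fn.__doc__ = doc_source.__doc__  # copy the docstring from the reference function
        return fn
    return decorator


def typed_wraps(doc_source: Callable[..., Any]) -> Callable[[F], F]:
    """Hypothetical annotated variant that preserves the wrapped signature."""
    def decorator(fn: F) -> F:
        fn.__doc__ = doc_source.__doc__
        return fn
    return decorator


def reference(x):
    """Reference docstring."""


@untyped_wraps(reference)
def clip_a(x: float, a_max: float) -> float:
    return min(x, a_max)


@typed_wraps(reference)
def clip_b(x: float, a_max: float) -> float:
    return min(x, a_max)
```

At runtime both functions behave identically. The difference only shows up under a static checker: mypy should report the result of clip_a(...) as Any, while clip_b(...) keeps its declared float return type.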
What jax/jaxlib version are you using?
0.4.13
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.10.12 on macOS
NVIDIA GPU info
No response