
Functions decorated with @util._wraps(...) lose their type annotations. #16863

@packquickly

Description

I've noticed that functions in _src/numpy/lax_numpy.py that are decorated with @util._wraps(...) lose their type annotations.
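The effect can be reproduced without JAX. Below is a minimal, hypothetical sketch: the names wraps_untyped, reference_clip and my_clip are made up, and it assumes _wraps is a decorator factory whose declared return type does not carry the decorated function's type through. At runtime the decorated function still works. A type checker, however, only sees the decorator's declared return type (Callable[..., Any]), so the float return annotation is discarded.

```python
from typing import TYPE_CHECKING, Any, Callable


def wraps_untyped(original: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical docstring-copying decorator factory (NOT JAX's _wraps).

    Its declared return type is Callable[..., Any], so type checkers treat
    anything it decorates as Any and forget the real signature.
    """
    def decorator(op: Callable[..., Any]) -> Callable[..., Any]:
        op.__doc__ = original.__doc__  # copy the docs; op itself is returned as-is
        return op
    return decorator


def reference_clip(x: float, lo: float, hi: float) -> float:
    """Clip x to the interval [lo, hi]."""
    return max(lo, min(x, hi))


@wraps_untyped(reference_clip)
def my_clip(x: float, lo: float, hi: float) -> float:
    return max(lo, min(x, hi))


result = my_clip(20.0, 0.0, 14.0)

if TYPE_CHECKING:
    reveal_type(result)  # checkers report Any here, not float
```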

For example, jnp.clip is decorated with @util._wraps(np.clip, skip_params=['out']). Consider the following code:

import jax.numpy as jnp
from typing import TYPE_CHECKING

x = jnp.arange(17)
x_clipped = jnp.clip(x, a_max=14)

if TYPE_CHECKING:
    reveal_type(x) # Array
    reveal_type(x_clipped) # Unknown or Any

mypy reports that x_clipped has type Any and Pyright reports Unknown, but it should be jax._src.basearray.Array. In lax_numpy, clip itself is correctly annotated; the annotation is lost once this decorator is applied.
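One possible fix, assuming _wraps only edits the docstring and returns the same function object (not verified against the 0.4.13 source), is to type the inner decorator with a bound TypeVar so type checkers pass the decorated function's own type straight through. A minimal sketch with hypothetical names:

```python
from typing import TYPE_CHECKING, Any, Callable, TypeVar

# Bound TypeVar: whatever callable type goes into the decorator comes back out.
_F = TypeVar("_F", bound=Callable[..., Any])


def wraps_typed(original: Callable[..., Any]) -> Callable[[_F], _F]:
    """Hypothetical type-preserving docstring-copying decorator factory.

    The inner decorator is typed _F -> _F, so checkers keep the decorated
    function's signature and return type instead of collapsing them to Any.
    """
    def decorator(op: _F) -> _F:
        op.__doc__ = original.__doc__  # only the docstring changes
        return op
    return decorator


def reference_clip(x: float, lo: float, hi: float) -> float:
    """Clip x to the interval [lo, hi]."""
    return max(lo, min(x, hi))


@wraps_typed(reference_clip)
def my_clip(x: float, lo: float, hi: float) -> float:
    return max(lo, min(x, hi))


result = my_clip(20.0, 0.0, 14.0)

if TYPE_CHECKING:
    reveal_type(result)  # now revealed as float
```

If _wraps instead returned a new wrapper function rather than the original object, typing.ParamSpec would be the more appropriate tool, since the wrapper's parameters and return type would then need to be expressed explicitly.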

What jax/jaxlib version are you using?

0.4.13

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.10.12 on macOS

NVIDIA GPU info

No response

Labels: bug (Something isn't working)
