diff --git a/RELEASES.md b/RELEASES.md index ccb9b97d2..dc5732562 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,7 @@ #### Closed issues - Fix missing cython file in MANIFEST.in (PR #763) +- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770) ## 0.9.6 diff --git a/ot/backend.py b/ot/backend.py index 64b5a88cf..f14da588b 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -119,7 +119,7 @@ import jax import jax.numpy as jnp import jax.scipy.special as jspecial - from jax.lib import xla_bridge + from jax.extend.backend import get_backend as _jax_get_backend jax_type = jax.numpy.ndarray jax_new_version = float(".".join(jax.__version__.split(".")[1:])) > 4.24 @@ -1509,7 +1509,7 @@ def __init__(self): self.__type_list__ = [] # available_devices = jax.devices("cpu") available_devices = [] - if xla_bridge.get_backend().platform == "gpu": + if _jax_get_backend().platform == "gpu": available_devices += jax.devices("gpu") for d in available_devices: self.__type_list__ += [