Skip to content

Commit f5781f8

Browse files
authored
try fix (#771)
1 parent 5a7d086 commit f5781f8

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#### Closed issues
88
- Fix missing cython file in MANIFEST.in (PR #763)
9+
- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
910

1011
## 0.9.6
1112

ot/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
import jax
120120
import jax.numpy as jnp
121121
import jax.scipy.special as jspecial
122-
from jax.lib import xla_bridge
122+
from jax.extend.backend import get_backend as _jax_get_backend
123123

124124
jax_type = jax.numpy.ndarray
125125
jax_new_version = float(".".join(jax.__version__.split(".")[1:])) > 4.24
@@ -1509,7 +1509,7 @@ def __init__(self):
15091509
self.__type_list__ = []
15101510
# available_devices = jax.devices("cpu")
15111511
available_devices = []
1512-
if xla_bridge.get_backend().platform == "gpu":
1512+
if _jax_get_backend().platform == "gpu":
15131513
available_devices += jax.devices("gpu")
15141514
for d in available_devices:
15151515
self.__type_list__ += [

0 commit comments

Comments
 (0)