File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change 66
77#### Closed issues
88- Fix missing cython file in MANIFEST.in (PR #763 )
9+ - Fix deprecated JAX function in ` ot.backend.JaxBackend ` (PR #771 , Issue #770 )
910
1011## 0.9.6
1112
Original file line number Diff line number Diff line change 119119 import jax
120120 import jax .numpy as jnp
121121 import jax .scipy .special as jspecial
122- from jax .lib import xla_bridge
122+ from jax .extend . backend import get_backend as _jax_get_backend
123123
124124 jax_type = jax .numpy .ndarray
125125 jax_new_version = float ("." .join (jax .__version__ .split ("." )[1 :])) > 4.24
@@ -1509,7 +1509,7 @@ def __init__(self):
15091509 self .__type_list__ = []
15101510 # available_devices = jax.devices("cpu")
15111511 available_devices = []
1512- if xla_bridge . get_backend ().platform == "gpu" :
1512+ if _jax_get_backend ().platform == "gpu" :
15131513 available_devices += jax .devices ("gpu" )
15141514 for d in available_devices :
15151515 self .__type_list__ += [
You can’t perform that action at this time.
0 commit comments