
Commit 25c7d88

Merge branch 'master' into fix-line-search-zero-cost
2 parents 96b3902 + 5ab00dd commit 25c7d88

8 files changed: +263 -93 lines


RELEASES.md

Lines changed: 11 additions & 0 deletions
@@ -1,5 +1,16 @@
 # Releases
 
+## 0.9.2dev
+
+#### New features
++ Tweaked `get_backend` to ignore `None` inputs (PR #525)
++ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+
+#### Closed issues
+- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
+- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
+
+
 ## 0.9.1
 *August 2023*
 
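As a quick illustration of the `get_backend` change listed above (PR #525), here is a minimal sketch (illustrative only, not part of the diff) assuming a NumPy input with an optional argument left as `None`:

```python
import numpy as np
from ot.backend import get_backend

a = np.array([0.5, 0.5])
weights = None  # optional argument that was never provided

# None entries are now ignored when inferring the backend;
# previously the mixed types raised a ValueError
nx = get_backend(a, weights)
print(nx.sum(a))  # 1.0, computed through the NumPy backend
```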
docs/source/quickstart.rst

Lines changed: 7 additions & 0 deletions
@@ -961,6 +961,13 @@ List of compatible Backends
 - `Tensorflow <https://www.tensorflow.org/>`_ (all outputs differentiable w.r.t. inputs)
 - `Cupy <https://cupy.dev/>`_ (no differentiation, GPU only)
 
+The library automatically detects which backends are available for use. A backend
+is instantiated lazily only when necessary, to prevent unwarranted GPU memory allocations.
+You can also disable the import of a specific backend library (e.g., to speed up
+loading of the `ot` library) with the environment variable `POT_BACKEND_DISABLE_<NAME>`, where <NAME> is one of (PYTORCH, TENSORFLOW, CUPY, JAX).
+For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`.
+Note that the `numpy` backend cannot be disabled.
+
 
 List of compatible modules
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
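As a Python equivalent of the `export` line documented above (illustrative only, not part of the diff), the variable must be set before `ot` is first imported, since the backend imports are guarded at import time in `ot/backend.py`:

```python
import os

# Must run before the first `import ot`: the guarded imports in
# ot/backend.py read the variable when the module is loaded
os.environ["POT_BACKEND_DISABLE_TENSORFLOW"] = "1"

import ot  # TensorFlow is now skipped during backend detection
```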

ot/backend.py

Lines changed: 124 additions & 64 deletions
@@ -87,43 +87,67 @@
 # License: MIT License
 
 import numpy as np
+import os
 import scipy
 import scipy.linalg
-import scipy.special as special
 from scipy.sparse import issparse, coo_matrix, csr_matrix
-import warnings
+import scipy.special as special
 import time
+import warnings
+
 
-try:
-    import torch
-    torch_type = torch.Tensor
-except ImportError:
+DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
+DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
+DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'
+DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'
+
+
+if not os.environ.get(DISABLE_TORCH_KEY, False):
+    try:
+        import torch
+        torch_type = torch.Tensor
+    except ImportError:
+        torch = False
+        torch_type = float
+else:
     torch = False
     torch_type = float
 
-try:
-    import jax
-    import jax.numpy as jnp
-    import jax.scipy.special as jspecial
-    from jax.lib import xla_bridge
-    jax_type = jax.numpy.ndarray
-except ImportError:
+if not os.environ.get(DISABLE_JAX_KEY, False):
+    try:
+        import jax
+        import jax.numpy as jnp
+        import jax.scipy.special as jspecial
+        from jax.lib import xla_bridge
+        jax_type = jax.numpy.ndarray
+    except ImportError:
+        jax = False
+        jax_type = float
+else:
     jax = False
     jax_type = float
 
-try:
-    import cupy as cp
-    import cupyx
-    cp_type = cp.ndarray
-except ImportError:
+if not os.environ.get(DISABLE_CUPY_KEY, False):
+    try:
+        import cupy as cp
+        import cupyx
+        cp_type = cp.ndarray
+    except ImportError:
+        cp = False
+        cp_type = float
+else:
     cp = False
     cp_type = float
 
-try:
-    import tensorflow as tf
-    import tensorflow.experimental.numpy as tnp
-    tf_type = tf.Tensor
-except ImportError:
+if not os.environ.get(DISABLE_TF_KEY, False):
+    try:
+        import tensorflow as tf
+        import tensorflow.experimental.numpy as tnp
+        tf_type = tf.Tensor
+    except ImportError:
+        tf = False
+        tf_type = float
+else:
     tf = False
     tf_type = float
 
@@ -132,40 +156,69 @@
 
 
 # Mapping between argument types and the existing backend
-_BACKENDS = []
+_BACKEND_IMPLEMENTATIONS = []
+_BACKENDS = {}
 
 
-def register_backend(backend):
-    _BACKENDS.append(backend)
+def _register_backend_implementation(backend_impl):
+    _BACKEND_IMPLEMENTATIONS.append(backend_impl)
 
 
-def get_backend_list():
-    """Returns the list of available backends"""
-    return _BACKENDS
+def _get_backend_instance(backend_impl):
+    if backend_impl.__name__ not in _BACKENDS:
+        _BACKENDS[backend_impl.__name__] = backend_impl()
+    return _BACKENDS[backend_impl.__name__]
 
 
-def _check_args_backend(backend, args):
-    is_instance = set(isinstance(a, backend.__type__) for a in args)
+def _check_args_backend(backend_impl, args):
+    is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
     # check that all arguments matched or not the type
     if len(is_instance) == 1:
         return is_instance.pop()
 
-    # Oterwise return an error
-    raise ValueError(str_type_error.format([type(a) for a in args]))
+    # Otherwise return an error
+    raise ValueError(str_type_error.format([type(arg) for arg in args]))
+
+
+def get_backend_list():
+    """Returns instances of all available backends.
+
+    Note that this function forces all detected implementations
+    to be instantiated, even if a specific backend was not used before.
+    Be careful, as instantiating a backend might lead to side effects,
+    like GPU memory pre-allocation. See the documentation for more details.
+    If you only need to know which implementations are available,
+    use :py:func:`ot.backend.get_available_backend_implementations`,
+    which does not force backend instances to be created.
+    """
+    return [
+        _get_backend_instance(backend_impl)
+        for backend_impl
+        in get_available_backend_implementations()
+    ]
+
+
+def get_available_backend_implementations():
+    """Returns the list of available backend implementations."""
+    return _BACKEND_IMPLEMENTATIONS
 
 
 def get_backend(*args):
     """Returns the proper backend for a list of input arrays
 
+    Accepts None entries in the arguments, and ignores them
+
     Also raises TypeError if all arrays are not from the same backend
     """
+    args = [arg for arg in args if arg is not None]  # exclude None entries
+
     # check that some arrays given
     if not len(args) > 0:
-        raise ValueError(" The function takes at least one parameter")
+        raise ValueError(" The function takes at least one (non-None) parameter")
 
-    for backend in _BACKENDS:
-        if _check_args_backend(backend, args):
-            return backend
+    for backend_impl in _BACKEND_IMPLEMENTATIONS:
+        if _check_args_backend(backend_impl, args):
+            return _get_backend_instance(backend_impl)
 
     raise ValueError("Unknown type of non implemented backend.")
 
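A short usage sketch of the lazy registry introduced in this hunk (illustrative only; the names come from the diff, and the printed values depend on which libraries are installed):

```python
import numpy as np
from ot.backend import get_available_backend_implementations, get_backend

# Lists backend classes without instantiating them, so simply asking
# what is available does not touch GPU memory
print(get_available_backend_implementations())

# Only the backend matching the inputs is instantiated, then cached in
# the _BACKENDS dict; None entries are filtered out beforehand
nx = get_backend(np.ones(3), None)
```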
@@ -407,7 +460,7 @@ def power(self, a, exponents):
         """
         raise NotImplementedError()
 
-    def norm(self, a):
+    def norm(self, a, axis=None, keepdims=False):
         r"""
         Computes the matrix frobenius norm.
 
@@ -627,7 +680,7 @@ def diag(self, a, k=0):
         """
         raise NotImplementedError()
 
-    def unique(self, a):
+    def unique(self, a, return_inverse=False):
         r"""
         Finds unique elements of given tensor.
 
@@ -1087,8 +1140,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return np.power(a, exponents)
 
-    def norm(self, a):
-        return np.sqrt(np.sum(np.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return np.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return np.any(a)
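A small sketch of what the extended `norm` signature allows, using the NumPy backend (illustrative only); with the default `axis=None`, `np.linalg.norm` still returns the flattened Frobenius norm, matching the previous `sqrt(sum(square))` behaviour:

```python
import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()
M = np.arange(6.0).reshape(2, 3)

print(nx.norm(M))                               # Frobenius norm, as before
print(nx.norm(M, axis=1))                       # one norm per row
print(nx.norm(M, axis=1, keepdims=True).shape)  # (2, 1)
```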
@@ -1164,8 +1217,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return np.diag(a, k)
 
-    def unique(self, a):
-        return np.unique(a)
+    def unique(self, a, return_inverse=False):
+        return np.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return special.logsumexp(a, axis=axis)
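Similarly, a small sketch of the new `return_inverse` option (illustrative only); the inverse indices map each original entry back to its position among the sorted unique values:

```python
import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()
a = np.array([3.0, 1.0, 3.0, 2.0])

u, inv = nx.unique(a, return_inverse=True)
print(u)       # [1. 2. 3.]
print(inv)     # [2 0 2 1]
print(u[inv])  # [3. 1. 3. 2.], the original array reconstructed
```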
@@ -1337,7 +1390,7 @@ def matmul(self, a, b):
         return np.matmul(a, b)
 
 
-register_backend(NumpyBackend())
+_register_backend_implementation(NumpyBackend)
 
 
 class JaxBackend(Backend):
@@ -1461,8 +1514,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return jnp.power(a, exponents)
 
-    def norm(self, a):
-        return jnp.sqrt(jnp.sum(jnp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return jnp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return jnp.any(a)
@@ -1535,8 +1588,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return jnp.diag(a, k)
 
-    def unique(self, a):
-        return jnp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return jnp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return jspecial.logsumexp(a, axis=axis)
@@ -1706,7 +1759,7 @@ def matmul(self, a, b):
 
 if jax:
     # Only register jax backend if it is installed
-    register_backend(JaxBackend())
+    _register_backend_implementation(JaxBackend)
 
 
 class TorchBackend(Backend):
@@ -1881,8 +1934,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return torch.pow(a, exponents)
 
-    def norm(self, a):
-        return torch.sqrt(torch.sum(torch.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
 
     def any(self, a):
         return torch.any(a)
@@ -1986,8 +2039,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return torch.diag(a, diagonal=k)
 
-    def unique(self, a):
-        return torch.unique(a)
+    def unique(self, a, return_inverse=False):
+        return torch.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         if axis is not None:
@@ -2189,7 +2242,7 @@ def matmul(self, a, b):
 
 if torch:
     # Only register torch backend if it is installed
-    register_backend(TorchBackend())
+    _register_backend_implementation(TorchBackend)
 
 
 class CupyBackend(Backend): # pragma: no cover
@@ -2306,8 +2359,8 @@ def power(self, a, exponents):
     def dot(self, a, b):
         return cp.dot(a, b)
 
-    def norm(self, a):
-        return cp.sqrt(cp.sum(cp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return cp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return cp.any(a)
@@ -2383,8 +2436,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return cp.diag(a, k)
 
-    def unique(self, a):
-        return cp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return cp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         # Taken from
@@ -2582,7 +2635,7 @@ def matmul(self, a, b):
 
 if cp:
     # Only register cp backend if it is installed
-    register_backend(CupyBackend())
+    _register_backend_implementation(CupyBackend)
 
 
 class TensorflowBackend(Backend):
@@ -2717,8 +2770,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return tnp.power(a, exponents)
 
-    def norm(self, a):
-        return tf.math.reduce_euclidean_norm(a)
+    def norm(self, a, axis=None, keepdims=False):
+        return tf.math.reduce_euclidean_norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return tnp.any(a)
@@ -2790,8 +2843,15 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return tnp.diag(a, k)
 
-    def unique(self, a):
-        return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+    def unique(self, a, return_inverse=False):
+        y, idx = tf.unique(tf.reshape(a, [-1]))
+        sort_idx = tf.argsort(y)
+        y_prime = tf.gather(y, sort_idx)
+        if return_inverse:
+            inv_sort_idx = tf.math.invert_permutation(sort_idx)
+            return y_prime, tf.gather(inv_sort_idx, idx)
+        else:
+            return y_prime
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)
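The TensorFlow variant above is the least obvious one: `tf.unique` returns unique values in order of first appearance, so the backend sorts them and routes the inverse indices through the inverted sort permutation. A NumPy sketch of the same idea (illustrative only, not part of the commit):

```python
import numpy as np

a = np.array([3.0, 1.0, 3.0, 2.0])

# Emulate tf.unique: unique values in order of first appearance,
# plus the index of each element into that list
y = list(dict.fromkeys(a.tolist()))               # [3.0, 1.0, 2.0]
idx = np.array([y.index(v) for v in a.tolist()])  # [0, 1, 0, 2]
y = np.array(y)

# What the TensorFlow backend does next: sort the uniques, invert the
# sort permutation, and remap the inverse indices through it
sort_idx = np.argsort(y)              # tf.argsort(y)
y_sorted = y[sort_idx]                # [1., 2., 3.]
inv_sort_idx = np.argsort(sort_idx)   # tf.math.invert_permutation(sort_idx)
inv = inv_sort_idx[idx]               # [2, 0, 2, 1]

assert np.array_equal(y_sorted[inv], a)
```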
@@ -2995,4 +3055,4 @@ def matmul(self, a, b):
 
 if tf:
     # Only register tensorflow backend if it is installed
-    register_backend(TensorflowBackend())
+    _register_backend_implementation(TensorflowBackend)
