 # License: MIT License
 
 import numpy as np
+import os
 import scipy
 import scipy.linalg
-import scipy.special as special
 from scipy.sparse import issparse, coo_matrix, csr_matrix
-import warnings
+import scipy.special as special
 import time
+import warnings
+

-try:
-    import torch
-    torch_type = torch.Tensor
-except ImportError:
+DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
+DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
+DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'
+DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'
+
+
+if not os.environ.get(DISABLE_TORCH_KEY, False):
+    try:
+        import torch
+        torch_type = torch.Tensor
+    except ImportError:
+        torch = False
+        torch_type = float
+else:
     torch = False
     torch_type = float
 
-try:
-    import jax
-    import jax.numpy as jnp
-    import jax.scipy.special as jspecial
-    from jax.lib import xla_bridge
-    jax_type = jax.numpy.ndarray
-except ImportError:
+if not os.environ.get(DISABLE_JAX_KEY, False):
+    try:
+        import jax
+        import jax.numpy as jnp
+        import jax.scipy.special as jspecial
+        from jax.lib import xla_bridge
+        jax_type = jax.numpy.ndarray
+    except ImportError:
+        jax = False
+        jax_type = float
+else:
     jax = False
     jax_type = float
 
-try:
-    import cupy as cp
-    import cupyx
-    cp_type = cp.ndarray
-except ImportError:
+if not os.environ.get(DISABLE_CUPY_KEY, False):
+    try:
+        import cupy as cp
+        import cupyx
+        cp_type = cp.ndarray
+    except ImportError:
+        cp = False
+        cp_type = float
+else:
     cp = False
     cp_type = float
 
-try:
-    import tensorflow as tf
-    import tensorflow.experimental.numpy as tnp
-    tf_type = tf.Tensor
-except ImportError:
+if not os.environ.get(DISABLE_TF_KEY, False):
+    try:
+        import tensorflow as tf
+        import tensorflow.experimental.numpy as tnp
+        tf_type = tf.Tensor
+    except ImportError:
+        tf = False
+        tf_type = float
+else:
     tf = False
     tf_type = float
 
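The guards above run once, at import time, so a backend can be opted out by setting the corresponding environment variable before `ot` is first imported. A minimal usage sketch (note that `os.environ.get(KEY, False)` returns the raw string, so any non-empty value, even `'0'`, disables the backend):

```python
import os

# Must be set before the first `import ot`; the guards are evaluated at import time.
# Any non-empty string disables the backend (only its truthiness is checked).
os.environ['POT_BACKEND_DISABLE_PYTORCH'] = '1'

import ot  # PyTorch is now skipped even if it is installed
```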
 
 
 # Mapping between argument types and the existing backend
-_BACKENDS = []
+_BACKEND_IMPLEMENTATIONS = []
+_BACKENDS = {}
 
 
-def register_backend(backend):
-    _BACKENDS.append(backend)
+def _register_backend_implementation(backend_impl):
+    _BACKEND_IMPLEMENTATIONS.append(backend_impl)
 
 
-def get_backend_list():
-    """Returns the list of available backends"""
-    return _BACKENDS
+def _get_backend_instance(backend_impl):
+    if backend_impl.__name__ not in _BACKENDS:
+        _BACKENDS[backend_impl.__name__] = backend_impl()
+    return _BACKENDS[backend_impl.__name__]
 
 
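Registration now stores the backend class rather than an eagerly constructed instance; `_get_backend_instance` creates one instance per implementation on first use and caches it under the class name. A small sketch of the resulting behavior, using the private helper purely for illustration:

```python
import ot.backend as backend
from ot.backend import NumpyBackend

# The first call instantiates NumpyBackend; subsequent calls hit the cache.
nx1 = backend._get_backend_instance(NumpyBackend)
nx2 = backend._get_backend_instance(NumpyBackend)
assert nx1 is nx2  # one shared instance per registered implementation
```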
-def _check_args_backend(backend, args):
-    is_instance = set(isinstance(a, backend.__type__) for a in args)
+def _check_args_backend(backend_impl, args):
+    is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
     # check that all arguments matched or not the type
     if len(is_instance) == 1:
         return is_instance.pop()
 
-    # Oterwise return an error
-    raise ValueError(str_type_error.format([type(a) for a in args]))
+    # Otherwise return an error
+    raise ValueError(str_type_error.format([type(arg) for arg in args]))
+
+
+def get_backend_list():
+    """Returns instances of all available backends.
+
+    Note that this function forces all detected implementations
+    to be instantiated, even if the specific backend was not used before.
+    Be careful, as instantiating a backend might lead to side effects,
+    like GPU memory pre-allocation. See the documentation for more details.
+    If you only need to know which implementations are available,
+    use :py:func:`ot.backend.get_available_backend_implementations`,
+    which does not force backend instances to be created.
+    """
+    return [
+        _get_backend_instance(backend_impl)
+        for backend_impl
+        in get_available_backend_implementations()
+    ]
+
+
+def get_available_backend_implementations():
+    """Returns the list of available backend implementations."""
+    return _BACKEND_IMPLEMENTATIONS
 
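This split keeps availability checks side-effect free: `get_available_backend_implementations` only lists the registered classes, while `get_backend_list` instantiates every one of them. A sketch of the intended usage:

```python
import ot.backend as backend

# Side-effect free: returns the registered classes; nothing is instantiated.
impls = backend.get_available_backend_implementations()
print([impl.__name__ for impl in impls])  # e.g. ['NumpyBackend', 'TorchBackend']

# Forces instantiation of every available backend (may pre-allocate GPU memory).
instances = backend.get_backend_list()
```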
 
 def get_backend(*args):
     """Returns the proper backend for a list of input arrays
 
+    Accepts None entries in the arguments, and ignores them
+
     Also raises TypeError if all arrays are not from the same backend
     """
+    args = [arg for arg in args if arg is not None]  # exclude None entries
+
     # check that some arrays given
     if not len(args) > 0:
-        raise ValueError(" The function takes at least one parameter")
+        raise ValueError(" The function takes at least one (non-None) parameter")
 
-    for backend in _BACKENDS:
-        if _check_args_backend(backend, args):
-            return backend
+    for backend_impl in _BACKEND_IMPLEMENTATIONS:
+        if _check_args_backend(backend_impl, args):
+            return _get_backend_instance(backend_impl)
 
     raise ValueError("Unknown type of non implemented backend.")
 
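Filtering out `None` lets callers forward optional array arguments without pre-filtering them. A minimal sketch:

```python
import numpy as np
from ot.backend import get_backend

a = np.ones(3)
b = None  # e.g. an optional weights argument that was not provided

nx = get_backend(a, b)  # the None entry is ignored; a NumpyBackend is returned
```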
@@ -407,7 +460,7 @@ def power(self, a, exponents):
         """
         raise NotImplementedError()
 
-    def norm(self, a):
+    def norm(self, a, axis=None, keepdims=False):
         r"""
         Computes the matrix frobenius norm.
 
@@ -627,7 +680,7 @@ def diag(self, a, k=0):
         """
         raise NotImplementedError()
 
-    def unique(self, a):
+    def unique(self, a, return_inverse=False):
         r"""
         Finds unique elements of given tensor.
 
@@ -1087,8 +1140,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return np.power(a, exponents)
 
-    def norm(self, a):
-        return np.sqrt(np.sum(np.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return np.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return np.any(a)
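Delegating to `np.linalg.norm` preserves the old behavior (a scalar Frobenius norm when `axis=None`) while adding per-axis reductions. A quick sketch of the extended signature:

```python
import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()
A = np.arange(6.0).reshape(2, 3)

nx.norm(A)                         # scalar Frobenius norm, as before
nx.norm(A, axis=1)                 # per-row Euclidean norms, shape (2,)
nx.norm(A, axis=1, keepdims=True)  # same values, shape (2, 1)
```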
@@ -1164,8 +1217,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return np.diag(a, k)
 
-    def unique(self, a):
-        return np.unique(a)
+    def unique(self, a, return_inverse=False):
+        return np.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return special.logsumexp(a, axis=axis)
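The new `return_inverse` flag follows `np.unique` semantics: the second output indexes into the sorted unique values so the original (flattened) input can be reconstructed. For example:

```python
import numpy as np
from ot.backend import NumpyBackend

nx = NumpyBackend()
a = np.array([2.0, 1.0, 2.0, 3.0])

vals, inv = nx.unique(a, return_inverse=True)
# vals -> [1. 2. 3.], inv -> [1 0 1 2]
assert np.array_equal(vals[inv], a)  # the inverse indices reconstruct the input
```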
@@ -1337,7 +1390,7 @@ def matmul(self, a, b):
         return np.matmul(a, b)
 
 
-register_backend(NumpyBackend())
+_register_backend_implementation(NumpyBackend)
 
 
 class JaxBackend(Backend):
@@ -1461,8 +1514,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return jnp.power(a, exponents)
 
-    def norm(self, a):
-        return jnp.sqrt(jnp.sum(jnp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return jnp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return jnp.any(a)
@@ -1535,8 +1588,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return jnp.diag(a, k)
 
-    def unique(self, a):
-        return jnp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return jnp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         return jspecial.logsumexp(a, axis=axis)
@@ -1706,7 +1759,7 @@ def matmul(self, a, b):
 
 if jax:
     # Only register jax backend if it is installed
-    register_backend(JaxBackend())
+    _register_backend_implementation(JaxBackend)
 
 
 class TorchBackend(Backend):
@@ -1881,8 +1934,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return torch.pow(a, exponents)
 
-    def norm(self, a):
-        return torch.sqrt(torch.sum(torch.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
 
     def any(self, a):
         return torch.any(a)
@@ -1986,8 +2039,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return torch.diag(a, diagonal=k)
 
-    def unique(self, a):
-        return torch.unique(a)
+    def unique(self, a, return_inverse=False):
+        return torch.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         if axis is not None:
@@ -2189,7 +2242,7 @@ def matmul(self, a, b):
 
 if torch:
     # Only register torch backend if it is installed
-    register_backend(TorchBackend())
+    _register_backend_implementation(TorchBackend)
 
 
 class CupyBackend(Backend):  # pragma: no cover
@@ -2306,8 +2359,8 @@ def power(self, a, exponents):
     def dot(self, a, b):
         return cp.dot(a, b)
 
-    def norm(self, a):
-        return cp.sqrt(cp.sum(cp.square(a)))
+    def norm(self, a, axis=None, keepdims=False):
+        return cp.linalg.norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return cp.any(a)
@@ -2383,8 +2436,8 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return cp.diag(a, k)
 
-    def unique(self, a):
-        return cp.unique(a)
+    def unique(self, a, return_inverse=False):
+        return cp.unique(a, return_inverse=return_inverse)
 
     def logsumexp(self, a, axis=None):
         # Taken from
@@ -2582,7 +2635,7 @@ def matmul(self, a, b):
 
 if cp:
     # Only register cp backend if it is installed
-    register_backend(CupyBackend())
+    _register_backend_implementation(CupyBackend)
 
 
 class TensorflowBackend(Backend):
@@ -2717,8 +2770,8 @@ def sqrt(self, a):
     def power(self, a, exponents):
         return tnp.power(a, exponents)
 
-    def norm(self, a):
-        return tf.math.reduce_euclidean_norm(a)
+    def norm(self, a, axis=None, keepdims=False):
+        return tf.math.reduce_euclidean_norm(a, axis=axis, keepdims=keepdims)
 
     def any(self, a):
         return tnp.any(a)
@@ -2790,8 +2843,15 @@ def meshgrid(self, a, b):
     def diag(self, a, k=0):
         return tnp.diag(a, k)
 
-    def unique(self, a):
-        return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+    def unique(self, a, return_inverse=False):
+        y, idx = tf.unique(tf.reshape(a, [-1]))
+        sort_idx = tf.argsort(y)
+        y_prime = tf.gather(y, sort_idx)
+        if return_inverse:
+            inv_sort_idx = tf.math.invert_permutation(sort_idx)
+            return y_prime, tf.gather(inv_sort_idx, idx)
+        else:
+            return y_prime
 
     def logsumexp(self, a, axis=None):
         return tf.math.reduce_logsumexp(a, axis=axis)
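The extra bookkeeping is needed because `tf.unique` returns values in order of first appearance, while `np.unique` returns them sorted, so the inverse indices must be remapped through the inverted sort permutation. A step-by-step sketch of what the implementation computes:

```python
import tensorflow as tf

a = tf.constant([2.0, 1.0, 2.0, 3.0])

y, idx = tf.unique(a)                      # y = [2., 1., 3.] (first-appearance order)
sort_idx = tf.argsort(y)                   # [1, 0, 2]
y_sorted = tf.gather(y, sort_idx)          # [1., 2., 3.], np.unique-style values
inv_sort_idx = tf.math.invert_permutation(sort_idx)
idx_sorted = tf.gather(inv_sort_idx, idx)  # [1, 0, 1, 2]

# tf.gather(y_sorted, idx_sorted) reconstructs `a`, matching np.unique semantics.
```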
@@ -2995,4 +3055,4 @@ def matmul(self, a, b):
 
 if tf:
     # Only register tensorflow backend if it is installed
-    register_backend(TensorflowBackend())
+    _register_backend_implementation(TensorflowBackend)