Skip to content

Commit 98a880e

Browse files
authored
Merge branch 'master' into semirelaxed_gromov
2 parents 7c758e3 + 058d275 commit 98a880e

25 files changed

+1150
-86
lines changed

RELEASES.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
- Added semi-relaxed (Fused) Gromov-Wasserstein solvers + examples (PR #401)
88
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
99
- Added Free Support Sinkhorn Barycenter + example (PR #387)
10+
- New API for OT solver using function `ot.solve` (PR #388)
11+
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
12+
1013

1114
#### Closed issues
1215

@@ -26,6 +29,14 @@ roughly 2^31) (PR #381)
2629
- Fixed an issue where a pytorch example would throw an error if executed on a GPU (Issue #389, PR #391)
2730
- Added a work-around for scipy's bug, where you cannot compute the Hamming distance with a "None" weight attribute. (Issue #400, PR #402)
2831
- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
32+
- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
33+
- Fixed weak optimal transport docstring (Issue #404, PR #410)
34+
- Fixed error with parameter `log=True` for `SinkhornLpl1Transport` (Issue #412,
35+
PR #413)
36+
- Fixed an issue about `warn` parameter in `sinkhorn2` (PR #417)
37+
- Fixed an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
38+
that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
39+
- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
2940

3041

3142
## 0.8.2

examples/gromov/plot_gromov_barycenter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ def im2mat(img):
110110
if shapes[nb][i, j] < 0.95:
111111
xs[nb].append([j, 8 - i])
112112

113-
xs = np.array([np.array(xs[0]), np.array(xs[1]),
114-
np.array(xs[2]), np.array(xs[3])])
113+
xs = [np.array(xs[s]) for s in range(S)]
115114

116115
##############################################################################
117116
# Barycenter computation

ot/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from . import regpath
3535
from . import weak
3636
from . import factored
37+
from . import solvers
3738

3839
# OT functions
3940
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -46,7 +47,7 @@
4647
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
4748
from .weak import weak_optimal_transport
4849
from .factored import factored_optimal_transport
49-
50+
from .solvers import solve
5051

5152
# utils functions
5253
from .utils import dist, unif, tic, toc, toq
@@ -61,5 +62,5 @@
6162
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
6263
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
6364
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
64-
'factored_optimal_transport',
65-
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
65+
'factored_optimal_transport', 'solve',
66+
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers']

ot/backend.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,21 @@ def sqrtm(self, a):
854854
"""
855855
raise NotImplementedError()
856856

857+
def kl_div(self, p, q, eps=1e-16):
858+
r"""
859+
Computes the Kullback-Leibler divergence.
860+
861+
This function follows the API from :any:`scipy.stats.entropy`.
862+
863+
Parameter eps is used to avoid numerical errors and is added in the log.
864+
865+
.. math::
866+
KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
867+
868+
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
869+
"""
870+
raise NotImplementedError()
871+
857872
def isfinite(self, a):
858873
r"""
859874
Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -1158,6 +1173,9 @@ def inv(self, a):
11581173
def sqrtm(self, a):
11591174
return scipy.linalg.sqrtm(a)
11601175

1176+
def kl_div(self, p, q, eps=1e-16):
1177+
return np.sum(p * np.log(p / q + eps))
1178+
11611179
def isfinite(self, a):
11621180
return np.isfinite(a)
11631181

@@ -1481,6 +1499,9 @@ def sqrtm(self, a):
14811499
L, V = jnp.linalg.eigh(a)
14821500
return (V * jnp.sqrt(L)[None, :]) @ V.T
14831501

1502+
def kl_div(self, p, q, eps=1e-16):
1503+
return jnp.sum(p * jnp.log(p / q + eps))
1504+
14841505
def isfinite(self, a):
14851506
return jnp.isfinite(a)
14861507

@@ -1901,6 +1922,9 @@ def sqrtm(self, a):
19011922
L, V = torch.linalg.eigh(a)
19021923
return (V * torch.sqrt(L)[None, :]) @ V.T
19031924

1925+
def kl_div(self, p, q, eps=1e-16):
1926+
return torch.sum(p * torch.log(p / q + eps))
1927+
19041928
def isfinite(self, a):
19051929
return torch.isfinite(a)
19061930

@@ -2248,6 +2272,9 @@ def sqrtm(self, a):
22482272
L, V = cp.linalg.eigh(a)
22492273
return (V * self.sqrt(L)[None, :]) @ V.T
22502274

2275+
def kl_div(self, p, q, eps=1e-16):
2276+
return cp.sum(p * cp.log(p / q + eps))
2277+
22512278
def isfinite(self, a):
22522279
return cp.isfinite(a)
22532280

@@ -2608,6 +2635,9 @@ def inv(self, a):
26082635
def sqrtm(self, a):
26092636
return tf.linalg.sqrtm(a)
26102637

2638+
def kl_div(self, p, q, eps=1e-16):
2639+
return tnp.sum(p * tnp.log(p / q + eps))
2640+
26112641
def isfinite(self, a):
26122642
return tnp.isfinite(a)
26132643

ot/bregman.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
207207
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
208208
weights (histograms, both sum to 1)
209209
210+
and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
211+
the entropic contribution).
212+
210213
.. note:: This function is backend-compatible and will work on arrays
211214
from all compatible backends.
212215
@@ -320,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
320323
if len(b.shape) < 2:
321324
if method.lower() == 'sinkhorn':
322325
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
323-
stopThr=stopThr, verbose=verbose, log=log,
326+
stopThr=stopThr, verbose=verbose,
327+
log=log, warn=warn,
324328
**kwargs)
325329
elif method.lower() == 'sinkhorn_log':
326330
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
327-
stopThr=stopThr, verbose=verbose, log=log,
331+
stopThr=stopThr, verbose=verbose,
332+
log=log, warn=warn,
328333
**kwargs)
329334
elif method.lower() == 'sinkhorn_stabilized':
330335
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
331-
stopThr=stopThr, verbose=verbose, log=log,
336+
stopThr=stopThr, verbose=verbose,
337+
log=log, warn=warn,
332338
**kwargs)
333339
else:
334340
raise ValueError("Unknown method '%s'." % method)
@@ -341,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
341347

342348
if method.lower() == 'sinkhorn':
343349
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
344-
stopThr=stopThr, verbose=verbose, log=log,
350+
stopThr=stopThr, verbose=verbose,
351+
log=log, warn=warn,
345352
**kwargs)
346353
elif method.lower() == 'sinkhorn_log':
347354
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
348-
stopThr=stopThr, verbose=verbose, log=log,
355+
stopThr=stopThr, verbose=verbose,
356+
log=log, warn=warn,
349357
**kwargs)
350358
elif method.lower() == 'sinkhorn_stabilized':
351359
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
352-
stopThr=stopThr, verbose=verbose, log=log,
360+
stopThr=stopThr, verbose=verbose,
361+
log=log, warn=warn,
353362
**kwargs)
354363
else:
355364
raise ValueError("Unknown method '%s'." % method)
@@ -1278,7 +1287,7 @@ def get_reg(n): # exponential decreasing
12781287
regi = get_reg(ii)
12791288

12801289
G, logi = sinkhorn_stabilized(a, b, M, regi,
1281-
numItermax=numInnerItermax, stopThr=1e-9,
1290+
numItermax=numInnerItermax, stopThr=stopThr,
12821291
warmstart=(alpha, beta), verbose=False,
12831292
print_period=20, tau=tau, log=True)
12841293

@@ -3059,6 +3068,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
30593068
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
30603069
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
30613070
3071+
and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
3072+
the entropic contribution).
3073+
30623074
30633075
Parameters
30643076
----------
@@ -3237,6 +3249,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
32373249
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
32383250
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
32393251
3252+
and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F -(\langle \gamma^*_a, \mathbf{M_a} \rangle_F + \langle
3253+
\gamma^*_b , \mathbf{M_b} \rangle_F)/2`.
3254+
3255+
.. note:: The current implementation does not account for the entropic contributions and thus differs from the
3256+
Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions
3257+
will be provided in a future release.
3258+
32403259
32413260
Parameters
32423261
----------
@@ -3293,17 +3312,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
32933312
if log:
32943313
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
32953314
numIterMax=numIterMax,
3296-
stopThr=1e-9, verbose=verbose,
3315+
stopThr=stopThr, verbose=verbose,
32973316
log=log, warn=warn, **kwargs)
32983317

32993318
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
33003319
numIterMax=numIterMax,
3301-
stopThr=1e-9, verbose=verbose,
3320+
stopThr=stopThr, verbose=verbose,
33023321
log=log, warn=warn, **kwargs)
33033322

33043323
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
33053324
numIterMax=numIterMax,
3306-
stopThr=1e-9, verbose=verbose,
3325+
stopThr=stopThr, verbose=verbose,
33073326
log=log, warn=warn, **kwargs)
33083327

33093328
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
@@ -3320,17 +3339,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33203339

33213340
else:
33223341
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
3323-
numIterMax=numIterMax, stopThr=1e-9,
3342+
numIterMax=numIterMax, stopThr=stopThr,
33243343
verbose=verbose, log=log,
33253344
warn=warn, **kwargs)
33263345

33273346
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
3328-
numIterMax=numIterMax, stopThr=1e-9,
3347+
numIterMax=numIterMax, stopThr=stopThr,
33293348
verbose=verbose, log=log,
33303349
warn=warn, **kwargs)
33313350

33323351
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
3333-
numIterMax=numIterMax, stopThr=1e-9,
3352+
numIterMax=numIterMax, stopThr=stopThr,
33343353
verbose=verbose, log=log,
33353354
warn=warn, **kwargs)
33363355

ot/da.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
126126
W = nx.zeros(M.shape, type_as=M)
127127
for cpt in range(numItermax):
128128
Mreg = M + eta * W
129-
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
130-
stopThr=stopInnerThr)
129+
if log:
130+
transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
131+
stopThr=stopInnerThr, log=True)
132+
else:
133+
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
134+
stopThr=stopInnerThr)
131135
# the transport has been computed. Check if classes are really
132136
# separated
133137
W = nx.ones(M.shape, type_as=M)
@@ -136,7 +140,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
136140
majs = p * ((majs + epsilon) ** (p - 1))
137141
W[indices_labels[i]] = majs
138142

139-
return transp
143+
if log:
144+
return transp, log
145+
else:
146+
return transp
140147

141148

142149
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,

ot/helpers/pre_build_helpers.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,14 @@
44
import sys
55
import glob
66
import tempfile
7-
import setuptools # noqa
87
import subprocess
98

10-
from distutils.dist import Distribution
11-
from distutils.sysconfig import customize_compiler
12-
from numpy.distutils.ccompiler import new_compiler
13-
from numpy.distutils.command.config_compiler import config_cc
9+
from setuptools.command.build_ext import customize_compiler, new_compiler
1410

1511

1612
def _get_compiler():
17-
"""Get a compiler equivalent to the one that will be used to build POT
18-
Handles compiler specified as follows:
19-
- python setup.py build_ext --compiler=<compiler>
20-
- CC=<compiler> python setup.py build_ext
21-
"""
22-
dist = Distribution({'script_name': os.path.basename(sys.argv[0]),
23-
'script_args': sys.argv[1:],
24-
'cmdclass': {'config_cc': config_cc}})
25-
26-
cmd_opts = dist.command_options.get('build_ext')
27-
if cmd_opts is not None and 'compiler' in cmd_opts:
28-
compiler = cmd_opts['compiler'][1]
29-
else:
30-
compiler = None
31-
32-
ccompiler = new_compiler(compiler=compiler)
13+
ccompiler = new_compiler()
3314
customize_compiler(ccompiler)
34-
3515
return ccompiler
3616

3717

ot/lp/network_simplex_simple_omp.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
//#include "core.h"
6868
//#include "lmath.h"
6969

70-
#ifdef OMP
70+
#ifdef _OPENMP
7171
#include <omp.h>
7272
#endif
7373
#include <cmath>
@@ -254,7 +254,7 @@ namespace lemon_omp {
254254
// Reset data structures
255255
reset();
256256
max_iter = maxiters;
257-
#ifdef OMP
257+
#ifdef _OPENMP
258258
if (max_threads < 0) {
259259
max_threads = omp_get_max_threads();
260260
}
@@ -513,7 +513,7 @@ namespace lemon_omp {
513513
int j;
514514
#pragma omp parallel
515515
{
516-
#ifdef OMP
516+
#ifdef _OPENMP
517517
int t = omp_get_thread_num();
518518
#else
519519
int t = 0;

0 commit comments

Comments
 (0)