From 140da40577e86b4546bbb033a43739820069fd11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 25 Oct 2021 16:54:10 +0200 Subject: [PATCH 01/17] adda sinkhorn log and working sinkhorn2 function --- ot/bregman.py | 177 ++++++++++++++++++++++++++++++++++++++++++- ot/dr.py | 4 +- test/test_bregman.py | 2 + test/test_helpers.py | 4 +- 4 files changed, 181 insertions(+), 6 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index b59ee1b4c..b7dd37b29 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,7 +24,7 @@ from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, +def sinkhorn(a, b, M, reg, method='sinkhorn_log', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -134,6 +134,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log) @@ -150,7 +154,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, raise ValueError("Unknown method '%s'." % method) -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, +def sinkhorn2(a, b, M, reg, method='sinkhorn_log', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -265,6 +269,10 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, @@ -438,6 +446,171 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) +def sinkhorn_log(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the entropic regularization optimal transport problem in log space + and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + .. _references-sinkhorn-log: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + if len(a) == 0: + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) + if len(b) == 0: + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) + + # init data + dim_a = len(a) + dim_b = len(b) + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = nx.zeros((dim_a, 1, n_hists), type_as=M) / dim_a + v = nx.zeros((dim_b, 1, n_hists), type_as=M) / dim_b + else: + u = nx.zeros(dim_a, type_as=M) / dim_a + v = nx.zeros(dim_b, type_as=M) / dim_b + + def get_logT(M, u, v): + if n_hists: + return (M - u - v) / (-reg) + else: + return (M - u[:, None] - v[None, :]) / (-reg) + loga = nx.log(a) + logb = nx.log(b) + + cpt = 0 + err = 1 + while (err > stopThr and cpt < numItermax): + + u = reg * (loga - nx.logsumexp(get_logT(M, u, v), 1)) + u + v = reg * (logb - nx.logsumexp(get_logT(M, u, v), 0)) + v + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + if n_hists: + tmp2 = nx.sum(nx.exp(get_logT(M, u, v)), 1, keepdims=True) + else: + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + tmp2 = nx.sum(nx.exp(get_logT(M, u, v)), 1) + err = nx.norm(tmp2 - b) # violation of marginal + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + + if log: + log['log_u'] = u + log['log_v'] = v + log['u'] = nx.exp(u / reg) + log['v'] = nx.exp(v / reg) + + if n_hists: # return only loss + res = nx.sum(nx.exp(get_logT(M, u, v)) * M, (0, 1)) + if n_hists==1: + res=res[0] + if log: + return res, log + else: + return res + + else: # return OT matrix + + if log: + return nx.exp(get_logT(M, u, v)), log + else: + return nx.exp(get_logT(M, u, v)) + + def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False): r""" diff --git a/ot/dr.py b/ot/dr.py index 64588cf8f..de39662f2 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -209,11 +209,11 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh .. math:: \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) - + - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold - :math:`H(\pi)` is entropy regularizer - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively - + Parameters ---------- X : ndarray, shape (n, d) diff --git a/test/test_bregman.py b/test/test_bregman.py index 942cb6d0d..865a0115a 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -147,6 +147,7 @@ def test_sinkhorn_variants(nx): Mb = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10) G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( @@ -155,6 +156,7 @@ def test_sinkhorn_variants(nx): # check values np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) diff --git a/test/test_helpers.py b/test/test_helpers.py index 8bd0015e4..cc4c90eaf 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -9,8 +9,8 @@ sys.path.append(os.path.join("ot", "helpers")) -from openmp_helpers import get_openmp_flag, check_openmp_support # noqa -from pre_build_helpers import _get_compiler, compile_test_program # noqa +from openmp_helpers import get_openmp_flag, check_openmp_support # noqa +from pre_build_helpers import _get_compiler, compile_test_program # noqa def test_helpers(): From ff6e356c8b66b4d21ef8d2b31f1d62c3e08307d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 25 Oct 2021 17:11:25 +0200 Subject: [PATCH 02/17] more tests pass --- ot/bregman.py | 7 +++---- ot/utils.py | 4 ++-- test/test_bregman.py | 8 ++++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index b7dd37b29..d8251a74b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -596,8 +596,8 @@ def get_logT(M, u, v): if n_hists: # return only loss res = nx.sum(nx.exp(get_logT(M, u, v)) * M, (0, 1)) - if n_hists==1: - res=res[0] + if n_hists == 1: + res = res[0] if log: return res, log else: @@ -2054,8 +2054,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', return (f, g) else: - M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) - M = nx.from_numpy(M, type_as=a) + M = dist(X_s, X_t, metric=metric) if log: pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) return pi, log diff --git a/ot/utils.py b/ot/utils.py index 6a782e689..0608aee01 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -183,7 +183,7 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric='sqeuclidean'): +def dist(x1, x2=None, metric='sqeuclidean', p=2): """Compute distance between samples in x1 and x2 .. note:: This function is backend-compatible and will work on arrays @@ -222,7 +222,7 @@ def dist(x1, x2=None, metric='sqeuclidean'): if not get_backend(x1, x2).__name__ == 'numpy': raise NotImplementedError() else: - return cdist(x1, x2, metric=metric) + return cdist(x1, x2, metric=metric, p=p) def dist0(n, method='lin_square'): diff --git a/test/test_bregman.py b/test/test_bregman.py index 865a0115a..d0597ee59 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -328,10 +328,10 @@ def test_empirical_sinkhorn(nx): a = ot.unif(n) b = ot.unif(n) - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n), (n, 1)) + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='minkowski') + M_m = ot.dist(X_s, X_t, metric='euclidean') ab = nx.from_numpy(a) bb = nx.from_numpy(b) @@ -348,7 +348,7 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski')) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) From c460f0ed1b369d4bdda707209a64b55ce9b7c4e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 25 Oct 2021 17:12:23 +0200 Subject: [PATCH 03/17] more tests pass --- test/test_bregman.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index d0597ee59..5b8319163 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -380,7 +380,7 @@ def test_lazy_empirical_sinkhorn(nx): X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='minkowski') + M_m = ot.dist(X_s, X_t, metric='euclidean') ab = nx.from_numpy(a) bb = nx.from_numpy(b) @@ -400,7 +400,7 @@ def test_lazy_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) From 59f2672fe52472be41853eb3a3ea95f367fa188d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 25 Oct 2021 17:35:12 +0200 Subject: [PATCH 04/17] it works but not by default yet --- ot/bregman.py | 6 +++--- test/test_bregman.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index d8251a74b..2c6812afc 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,7 +24,7 @@ from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn_log', numItermax=1000, +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -154,7 +154,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn_log', numItermax=1000, raise ValueError("Unknown method '%s'." % method) -def sinkhorn2(a, b, M, reg, method='sinkhorn_log', numItermax=1000, +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -595,7 +595,7 @@ def get_logT(M, u, v): log['v'] = nx.exp(v / reg) if n_hists: # return only loss - res = nx.sum(nx.exp(get_logT(M, u, v)) * M, (0, 1)) + res = nx.sum(nx.exp(get_logT(M, u, v)) * M[:, :, None], (0, 1)) if n_hists == 1: res = res[0] if log: diff --git a/test/test_bregman.py b/test/test_bregman.py index 5b8319163..f8a4897ae 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -147,7 +147,7 @@ def test_sinkhorn_variants(nx): Mb = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( @@ -174,6 +174,7 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) @@ -181,6 +182,7 @@ def test_sinkhorn_variants_log(): # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) print(G0, G_green) From 1c105ad5d563f5c9c524c5208305d4cc2a8c8516 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 25 Oct 2021 17:38:58 +0200 Subject: [PATCH 05/17] remove warningd --- ot/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/optim.py b/ot/optim.py index 6822e4eba..34cbb178c 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -20,7 +20,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99): - """ + r""" Armijo linesearch function that works with matrices Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the @@ -447,7 +447,7 @@ def cost(G): def solve_1d_linesearch_quad(a, b, c): - """ + r""" For any convex or non-convex 1d quadratic function `f`, solve the following problem: .. math:: From 31d3666ff9338d85a735e6507fecc1815e7b118b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 25 Oct 2021 18:25:46 +0200 Subject: [PATCH 06/17] update circleci doc --- .circleci/config.yml | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index e4c71dde1..379394aea 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,7 +4,7 @@ version: 2 jobs: build_docs: docker: - - image: circleci/python:3.7-stretch + - image: cimg/python:3.9 steps: - checkout - run: @@ -34,18 +34,6 @@ jobs: - data-cache-0 - pip-cache - - run: - name: Spin up Xvfb - command: | - /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset; - - # https://github.com/ContinuumIO/anaconda-issues/issues/9190#issuecomment-386508136 - # https://github.com/golemfactory/golem/issues/1019 - - run: - name: Fix libgcc_s.so.1 pthread_cancel bug - command: | - sudo apt-get install qt5-default - - run: name: Get Python running command: | From b88ed3946699c86c4697f7088a3d7576ae51ff16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 08:15:04 +0200 Subject: [PATCH 07/17] update circleci doc --- ot/bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/bregman.py b/ot/bregman.py index 2c6812afc..1408c7aaf 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,7 +24,7 @@ from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, +def sinkhorn(a, b, M, reg, method='sinkhorn_log', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix From 57cb7a37f01b6852c98847db9c02dcc267a19d7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 11:46:03 +0200 Subject: [PATCH 08/17] new sinkhorn implemeted but not by default --- ot/bregman.py | 174 ++++++++++++++++++++++++++----------------- ot/gromov.py | 4 +- test/test_bregman.py | 21 ++++++ test/test_gromov.py | 10 +-- 4 files changed, 133 insertions(+), 76 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 1408c7aaf..359aadb9c 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,7 +24,7 @@ from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn_log', numItermax=1000, +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -262,23 +262,44 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, """ b = list_to_array(b) + nx = get_backend(M, a, b) + if len(b.shape) < 2: - b = b[:, None] + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) else: - raise ValueError("Unknown method '%s'." % method) + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) def sinkhorn_knopp(a, b, M, reg, numItermax=1000, @@ -369,7 +390,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # init data dim_a = len(a) - dim_b = len(b) + dim_b = b.shape[0] if len(b.shape) > 1: n_hists = b.shape[1] @@ -535,80 +556,95 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, # init data dim_a = len(a) - dim_b = len(b) + dim_b = b.shape[0] if len(b.shape) > 1: n_hists = b.shape[1] else: n_hists = 0 - if log: - log = {'err': []} + if n_hists: # we do not want to use tensors sor we do a loop - # we assume that no distances are null except those of the diagonal of - # distances - if n_hists: - u = nx.zeros((dim_a, 1, n_hists), type_as=M) / dim_a - v = nx.zeros((dim_b, 1, n_hists), type_as=M) / dim_b - else: - u = nx.zeros(dim_a, type_as=M) / dim_a - v = nx.zeros(dim_b, type_as=M) / dim_b + lst_loss = [] + lst_u = [] + lst_v = [] - def get_logT(M, u, v): - if n_hists: - return (M - u - v) / (-reg) + for k in range(n_hists): + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + if log: + lst_loss.append(nx.sum(M * res[0])) + lst_u.append(res[1]['log_u']) + lst_v.append(res[1]['log_v']) + else: + lst_loss.append(nx.sum(M * res)) + res = nx.stack(lst_loss) + if log: + log = {'log_u': nx.stack(lst_u, 1), + 'log_v': nx.stack(lst_v, 1), } + log['u'] = nx.exp(log['log_u']) + log['v'] = nx.exp(log['log_v']) + return res, log else: - return (M - u[:, None] - v[None, :]) / (-reg) - loga = nx.log(a) - logb = nx.log(b) + return res - cpt = 0 - err = 1 - while (err > stopThr and cpt < numItermax): + else: - u = reg * (loga - nx.logsumexp(get_logT(M, u, v), 1)) + u - v = reg * (logb - nx.logsumexp(get_logT(M, u, v), 0)) + v + if log: + log = {'err': []} - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations + Mr = M / (-reg) + + # we assume that no distances are null except those of the diagonal of + # distances + + u = nx.zeros(dim_a, type_as=M) + v = nx.zeros(dim_b, type_as=M) + + def get_logT(u, v): if n_hists: - tmp2 = nx.sum(nx.exp(get_logT(M, u, v)), 1, keepdims=True) + return Mr[:, :, None] + u + v else: - # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - tmp2 = nx.sum(nx.exp(get_logT(M, u, v)), 1) - err = nx.norm(tmp2 - b) # violation of marginal - if log: - log['err'].append(err) + return Mr + u[:, None] + v[None, :] - if verbose: - if cpt % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + loga = nx.log(a) + logb = nx.log(b) - if log: - log['log_u'] = u - log['log_v'] = v - log['u'] = nx.exp(u / reg) - log['v'] = nx.exp(v / reg) + cpt = 0 + err = 1 + while (err > stopThr and cpt < numItermax): - if n_hists: # return only loss - res = nx.sum(nx.exp(get_logT(M, u, v)) * M[:, :, None], (0, 1)) - if n_hists == 1: - res = res[0] - if log: - return res, log - else: - return res + v = logb - nx.logsumexp(Mr + u[:, None], 0) + u = loga - nx.logsumexp(Mr + v[None, :], 1) - else: # return OT matrix + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0) + err = nx.norm(tmp2 - b) # violation of marginal + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 if log: - return nx.exp(get_logT(M, u, v)), log + log['log_u'] = u + log['log_v'] = v + log['u'] = nx.exp(u) + log['v'] = nx.exp(v) + + return nx.exp(get_logT(u, v)), log + else: - return nx.exp(get_logT(M, u, v)) + return nx.exp(get_logT(u, v)) def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, diff --git a/ot/gromov.py b/ot/gromov.py index 85b1549d4..33b445326 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1030,7 +1030,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, # compute the gradient tens = gwggrad(constC, hC1, hC2, T) - T = sinkhorn(p, q, tens, epsilon) + T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -1204,7 +1204,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Cprev = C T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, - max_iter, 1e-5, verbose, log) for s in range(S)] + max_iter, 1e-4, verbose, log) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) diff --git a/test/test_bregman.py b/test/test_bregman.py index f8a4897ae..14ffc22ae 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -32,6 +32,27 @@ def test_sinkhorn(): u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +def test_sinkhorn_multi_b(): + # test sinkhorn + n = 10 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) + + loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] + # check constraints + np.testing.assert_allclose( + loss0, loss, atol=1e-05) # cf convergence sinkhorn + + def test_sinkhorn_backends(nx): n_samples = 100 n_features = 2 diff --git a/test/test_gromov.py b/test/test_gromov.py index 19d61b104..0242d7247 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -180,8 +180,8 @@ def loss(x, y): def test_gromov_barycenter(): - ns = 50 - nt = 60 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -208,8 +208,8 @@ def test_gromov_barycenter(): @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(): - ns = 20 - nt = 30 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -222,7 +222,7 @@ def test_gromov_entropic_barycenter(): [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], 'square_loss', 1e-3, - max_iter=50, tol=1e-5, + max_iter=50, tol=1e-3, verbose=True) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) From 4ef9bf3e92a014e56d3b0c02c89b9dee7a4b72dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 11:52:48 +0200 Subject: [PATCH 09/17] better --- ot/bregman.py | 2 +- test/test_bregman.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 359aadb9c..877b8f58e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -261,7 +261,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, """ - b = list_to_array(b) + M, a, b = list_to_array(M, a, b) nx = get_backend(M, a, b) if len(b.shape) < 2: diff --git a/test/test_bregman.py b/test/test_bregman.py index 14ffc22ae..bbadb27e1 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -50,7 +50,7 @@ def test_sinkhorn_multi_b(): loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] # check constraints np.testing.assert_allclose( - loss0, loss, atol=1e-05) # cf convergence sinkhorn + loss0, loss, atol=1e-06) # cf convergence sinkhorn def test_sinkhorn_backends(nx): From f82c5019a7a875b405062b8f3b3829d7f0bf83ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 12:26:44 +0200 Subject: [PATCH 10/17] doctest pass --- ot/bregman.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 877b8f58e..dfafc501a 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -234,7 +234,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn2(a, b, M, 1) - array([0.26894142]) + 0.26894142136999516 .. _references-sinkhorn2: @@ -2310,7 +2310,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS - array([1.499...]) + 1.49988717604905 References From f7698510d8ded16db1917b527a2573684300b8b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 15:15:16 +0200 Subject: [PATCH 11/17] test doctest --- ot/bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/bregman.py b/ot/bregman.py index dfafc501a..431317237 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -2310,7 +2310,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS - 1.49988717604905 + 1.499887176049052 References From b1c266e30c4ab192d9e6c53a2ee7a3d59bfbfed0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 16:58:45 +0200 Subject: [PATCH 12/17] new test utils --- test/test_bregman.py | 30 +++++++++++++++++++++++++++++- test/test_utils.py | 12 ++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index bbadb27e1..f4f2e7dc7 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -181,7 +181,35 @@ def test_sinkhorn_variants(nx): np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) - print(G0, G_green) + + +@pytest.skip_backend("jax") +def test_sinkhorn_variants_multi_b(nx): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub = nx.from_numpy(u) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) def test_sinkhorn_variants_log(): diff --git a/test/test_utils.py b/test/test_utils.py index 60ad5d3de..c70f9cf68 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import ot import numpy as np import sys +import pytest def test_proj_simplex(nx): @@ -108,6 +109,8 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) + D4 = ot.dist(x, x, metric='minkowski', p=0.5) + # dist shoul return squared euclidean np.testing.assert_allclose(D, D2, atol=1e-14) np.testing.assert_allclose(D, D3, atol=1e-14) @@ -220,6 +223,12 @@ def fun2(): class Class(): pass + with pytest.warns(DeprecationWarning): + fun() + + with pytest.warns(DeprecationWarning): + cl = Class() + if sys.version_info < (3, 5): print('Not tested') else: @@ -250,4 +259,7 @@ def __init__(self, first='spam', second='eggs'): params['first'] = 'spam again' cl.set_params(**params) + with pytest.raises(ValueError): + cl.set_params(bibi=10) + assert cl.first == 'spam again' From 43c92f813a47abb572fd0a0fc64bcf9655109c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 17:01:09 +0200 Subject: [PATCH 13/17] remove pep8 errors --- test/test_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index c70f9cf68..d2dc2704e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -111,6 +111,8 @@ def test_dist(): D4 = ot.dist(x, x, metric='minkowski', p=0.5) + assert D[0, 1] == D[1, 0] + # dist shoul return squared euclidean np.testing.assert_allclose(D, D2, atol=1e-14) np.testing.assert_allclose(D, D3, atol=1e-14) @@ -228,6 +230,7 @@ class Class(): with pytest.warns(DeprecationWarning): cl = Class() + print(cl) if sys.version_info < (3, 5): print('Not tested') From d885c5f3073f168ededa553ba5762d6f3612b027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 17:02:28 +0200 Subject: [PATCH 14/17] remove pep8 errors --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index d2dc2704e..0650ce2a6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -111,7 +111,7 @@ def test_dist(): D4 = ot.dist(x, x, metric='minkowski', p=0.5) - assert D[0, 1] == D[1, 0] + assert D4[0, 1] == D4[1, 0] # dist shoul return squared euclidean np.testing.assert_allclose(D, D2, atol=1e-14) From 8870aa405c7182107e4c5245a3aa0a4823876493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 17:21:07 +0200 Subject: [PATCH 15/17] doc new implementtaion with log --- ot/bregman.py | 12 ++++++++---- test/test_bregman.py | 26 +++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 431317237..3518bcbca 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -56,7 +56,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, By default and when using a regularization parameter that is not too small the default sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the - :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + :py:func:`ot.bregman.sinkhorn_log` solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value @@ -64,7 +64,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, solutions. Note that the greedy version of the sinkhorn :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation we strongly reconmend the :any:`ot.bregman.sinkhorn_log` solver + that will no need to check for numerical problems. Parameters @@ -186,7 +188,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, By default and when using a regularization parameter that is not too small the default sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + :any:`ot.bregman.sinkhorn_log` solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value @@ -194,7 +196,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, solutions. Note that the greedy version of the sinkhorn :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation we strongly reconmend the :any:`ot.bregman.sinkhorn_log` solver + that will no need to check for numerical problems. Parameters ---------- diff --git a/test/test_bregman.py b/test/test_bregman.py index f4f2e7dc7..0bdd931a6 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -186,7 +186,7 @@ def test_sinkhorn_variants(nx): @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn - n = 100 + n = 50 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -214,7 +214,7 @@ def test_sinkhorn_variants_multi_b(nx): def test_sinkhorn_variants_log(): # test sinkhorn - n = 100 + n = 50 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -234,7 +234,27 @@ def test_sinkhorn_variants_log(): np.testing.assert_allclose(G0, Gl, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) np.testing.assert_allclose(G0, G_green, atol=1e-5) - print(G0, G_green) + + +def test_sinkhorn_variants_log_multib(): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Gl, atol=1e-05) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) From 372a8b166410d9d29a2d0f891c59f94bbffee1e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 26 Oct 2021 17:32:24 +0200 Subject: [PATCH 16/17] test sinkhorn 2 --- test/test_bregman.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_bregman.py b/test/test_bregman.py index 0bdd931a6..c1120ba9c 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -212,6 +212,35 @@ def test_sinkhorn_variants_multi_b(nx): np.testing.assert_allclose(G0, Gs, atol=1e-05) +@pytest.skip_backend("jax") +def test_sinkhorn2_variants_multi_b(nx): + # test sinkhorn + n = 50 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + b = rng.rand(n, 3) + b = b / np.sum(b, 0, keepdims=True) + + M = ot.dist(x, x) + + ub = nx.from_numpy(u) + bb = nx.from_numpy(b) + Mb = nx.from_numpy(M) + + G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + + # check values + np.testing.assert_allclose(G, G0, atol=1e-05) + np.testing.assert_allclose(G, Gl, atol=1e-05) + np.testing.assert_allclose(G0, Gs, atol=1e-05) + + def test_sinkhorn_variants_log(): # test sinkhorn n = 50 From a6b297aadeba0ff4e1ba00e654eea8ec7f5b47cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 27 Oct 2021 08:25:46 +0200 Subject: [PATCH 17/17] doc for log implementation --- README.md | 4 +++- docs/source/quickstart.rst | 10 +++++++++- ot/bregman.py | 36 +++++++++++++++++++++++++----------- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 266d847c2..ffad0bd0f 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples): * [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. -* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). +* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. @@ -290,3 +290,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). [33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 + +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index fd046a162..232df7be3 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -358,6 +358,11 @@ More details about the algorithms used are given in the following note. + :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the classic algorithm [2]_. + + :code:`method='sinkhorn_log'` calls :any:`ot.bregman.sinkhorn_log` the + sinkhorn algorithm in log space [2]_ that is more stable but can be + slower in numpy since `logsumexp` is not implmemented in parallel. + It is the recommended solver for applications that requires + differentiability with a small number of iterations. + :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the log stabilized version of the algorithm [9]_. + :code:`method='sinkhorn_epsilon_scaling'` calls @@ -389,7 +394,10 @@ More details about the algorithms used are given in the following note. solutions. Note that the greedy version of the Sinkhorn :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the Sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. diff --git a/ot/bregman.py b/ot/bregman.py index 3518bcbca..2aa76ff6e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -56,7 +56,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, By default and when using a regularization parameter that is not too small the default sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the - :py:func:`ot.bregman.sinkhorn_log` solver that will avoid numerical + :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value @@ -65,8 +65,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a fast approximation of the Sinkhorn problem. For use of GPU and gradient - computation we strongly reconmend the :any:`ot.bregman.sinkhorn_log` solver - that will no need to check for numerical problems. + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. Parameters @@ -81,8 +82,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn','sinkhorn_log', + 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see + those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -120,6 +122,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. See Also @@ -197,8 +200,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a fast approximation of the Sinkhorn problem. For use of GPU and gradient - computation we strongly reconmend the :any:`ot.bregman.sinkhorn_log` solver - that will no need to check for numerical problems. + computation with small number of iterations we strongly recommend the + :any:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. Parameters ---------- @@ -212,7 +216,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters + method used for the solver either 'sinkhorn','sinkhorn_log', + 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -251,7 +256,11 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation + algorithms for optimal transport via Sinkhorn iteration, Advances in Neural + Information Processing Systems (NIPS) 31, 2017 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. @@ -493,7 +502,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm :ref:`[2] ` with the + implementation from :ref:`[34] ` Parameters @@ -539,7 +550,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. See Also