Skip to content

Commit 441a3a7

Browse files
authored
Merge branch 'master' into fix-line-search-zero-cost
2 parents 25c7d88 + 7856700 commit 441a3a7

File tree

12 files changed

+1382
-583
lines changed

12 files changed

+1382
-583
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ POT provides the following generic OT solvers (links to examples):
4545
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
4646
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
4747
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
48+
* Smooth Strongly Convex Nearest Brenier Potentials [58], with an extension to bounding potentials using [59].
4849

4950
POT provides the following Machine Learning related solvers:
5051

@@ -329,3 +330,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
329330
distances between Gaussian distributions](https://hal.science/hal-03197398v2/file/main.pdf). Journal of Applied Probability, 59(4),
330331
1178-1198.
331332

333+
[58] Paty F-P., d’Aspremont 1., & Cuturi M. (2020). [Regularity as regularization:Smooth and strongly convex brenier potentials in optimal transport.](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) In International Conference on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.
334+
335+
[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017.
336+

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
## 0.9.2dev
44

55
#### New features
6-
+ Tweaked `get_backend` to ignore `None` inputs (PR # 525)
6+
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526)
7+
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
78
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
89

910
#### Closed issues

docs/requirements_rtd.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ matplotlib
1111
autograd
1212
pymanopt
1313
cvxopt
14-
scikit-learn
14+
scikit-learn
15+
cvxpy

docs/source/all.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ API and modules
2525
gnn
2626
gromov
2727
lp
28+
mapping
2829
optim
2930
partial
3031
plot

examples/others/plot_SSNB.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# -*- coding: utf-8 -*-
2+
r"""
3+
=====================================================
4+
Smooth and Strongly Convex Nearest Brenier Potentials
5+
=====================================================
6+
7+
This example is designed to show how to use SSNB [58] in POT.
8+
SSNB computes an l-strongly convex potential :math:`\varphi` with an L-Lipschitz gradient such that
9+
:math:`\nabla \varphi \# \mu \approx \nu`. This regularity can be enforced only on the components of a partition
10+
of the ambient space, which is a relaxation compared to imposing global regularity.
11+
12+
In this example, we consider a source measure :math:`\mu_s` which is the uniform measure on the unit square in
13+
:math:`\mathbb{R}^2`, and the target measure :math:`\mu_t` which is the image of :math:`\mu_x` by
14+
:math:`T(x_1, x_2) = (x_1 + 2\mathrm{sign}(x_2), 2 * x_2)`. The map :math:`T` is non-smooth, and we wish to approximate
15+
it using a "Brenier-style" map :math:`\nabla \varphi` which is regular on the partition
16+
:math:`\lbrace x_1 <=0, x_1>0\rbrace`, which is well adapted to this particular dataset.
17+
18+
We represent the gradients of the "bounding potentials" :math:`\varphi_l, \varphi_u` (from [59], Theorem 3.14),
19+
which bound any SSNB potential which is optimal in the sense of [58], Definition 1:
20+
21+
.. math::
22+
\varphi \in \mathrm{argmin}_{\varphi \in \mathcal{F}}\ \mathrm{W}_2(\nabla \varphi \#\mu_s, \mu_t),
23+
24+
where :math:`\mathcal{F}` is the space functions that are on every set :math:`E_k` l-strongly convex
25+
with an L-Lipschitz gradient, given :math:`(E_k)_{k \in [K]}` a partition of the ambient source space.
26+
27+
We perform the optimisation on a low amount of fitting samples and with few iterations,
28+
since solving the SSNB problem is quite computationally expensive.
29+
30+
THIS EXAMPLE REQUIRES CVXPY
31+
32+
.. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization:
33+
Smooth and strongly convex brenier potentials in optimal transport. In International Conference
34+
on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.
35+
36+
.. [59] Adrien B Taylor. Convex interpolation and performance estimation of first-order methods for
37+
convex optimization. PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium,
38+
2017.
39+
"""
40+
41+
# Author: Eloi Tanguy <[email protected]>
42+
# License: MIT License
43+
44+
# sphinx_gallery_thumbnail_number = 4
45+
46+
import matplotlib.pyplot as plt
47+
import numpy as np
48+
import ot
49+
50+
# %%
51+
# Generating the fitting data
52+
n_fitting_samples = 30
53+
rng = np.random.RandomState(seed=0)
54+
Xs = rng.uniform(-1, 1, size=(n_fitting_samples, 2))
55+
Xs_classes = (Xs[:, 0] < 0).astype(int)
56+
Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), 2 * Xs[:, 1]], axis=-1)
57+
58+
plt.scatter(Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c='blue', label='source class 0')
59+
plt.scatter(Xs[Xs_classes == 1, 0], Xs[Xs_classes == 1, 1], c='dodgerblue', label='source class 1')
60+
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
61+
plt.axis('equal')
62+
plt.title('Splitting sphere dataset')
63+
plt.legend(loc='upper right')
64+
plt.show()
65+
66+
# %%
67+
# Plotting image of barycentric projection (SSNB initialisation values)
68+
plt.clf()
69+
pi = ot.emd(ot.unif(n_fitting_samples), ot.unif(n_fitting_samples), ot.dist(Xs, Xt))
70+
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
71+
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
72+
bar_img = pi @ Xt
73+
for i in range(n_fitting_samples):
74+
plt.plot([Xs[i, 0], bar_img[i, 0]], [Xs[i, 1], bar_img[i, 1]], color='black', alpha=.5)
75+
plt.title('Images of in-data source samples by the barycentric map')
76+
plt.legend(loc='upper right')
77+
plt.axis('equal')
78+
plt.show()
79+
80+
# %%
81+
# Fitting the Nearest Brenier Potential
82+
L = 3 # need L > 2 to allow the 2*y term, default is 1.4
83+
phi, G = ot.mapping.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric',
84+
gradient_lipschitz_constant=L)
85+
86+
# %%
87+
# Plotting the images of the source data
88+
plt.clf()
89+
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
90+
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
91+
for i in range(n_fitting_samples):
92+
plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 1], G[i, 1]], color='black', alpha=.5)
93+
plt.title('Images of in-data source samples by the fitted SSNB')
94+
plt.legend(loc='upper right')
95+
plt.axis('equal')
96+
plt.show()
97+
98+
# %%
99+
# Computing the predictions (images by nabla phi) for random samples of the source distribution
100+
n_predict_samples = 50
101+
Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2))
102+
Ys_classes = (Ys[:, 0] < 0).astype(int)
103+
phi_lu, G_lu = ot.mapping.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes,
104+
gradient_lipschitz_constant=L)
105+
106+
# %%
107+
# Plot predictions for the gradient of the lower-bounding potential
108+
plt.clf()
109+
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
110+
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
111+
for i in range(n_predict_samples):
112+
plt.plot([Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color='black', alpha=.5)
113+
plt.title('Images of new source samples by $\\nabla \\varphi_l$')
114+
plt.legend(loc='upper right')
115+
plt.axis('equal')
116+
plt.show()
117+
118+
# %%
119+
# Plot predictions for the gradient of the upper-bounding potential
120+
plt.clf()
121+
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
122+
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
123+
for i in range(n_predict_samples):
124+
plt.plot([Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color='black', alpha=.5)
125+
plt.title('Images of new source samples by $\\nabla \\varphi_u$')
126+
plt.legend(loc='upper right')
127+
plt.axis('equal')
128+
plt.show()

ot/__init__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
66
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
77
:py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`
8-
, :py:mod:`ot.unbalanced`.
8+
, :py:mod:`ot.unbalanced`, :py:mod`ot.mapping`.
99
The following sub-modules are not imported due to additional dependencies:
1010
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
1111
- :any:`ot.plot` : depends on :code:`matplotlib`
@@ -37,35 +37,35 @@
3737
from . import gaussian
3838

3939
# OT functions
40-
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
41-
binary_search_circle, wasserstein_circle,
42-
semidiscrete_wasserstein2_unif_circle)
40+
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
41+
binary_search_circle, wasserstein_circle,
42+
semidiscrete_wasserstein2_unif_circle)
4343
from .bregman import sinkhorn, sinkhorn2, barycenter
4444
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
4545
sinkhorn_unbalanced2)
4646
from .da import sinkhorn_lpl1_mm
47-
from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
47+
from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
4848
sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif)
4949
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
50-
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
50+
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
5151
from .weak import weak_optimal_transport
5252
from .factored import factored_optimal_transport
5353
from .solvers import solve
5454

5555
# utils functions
5656
from .utils import dist, unif, tic, toc, toq
5757

58-
__version__ = "0.9.1"
58+
__version__ = "0.9.2dev"
5959

6060
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
6161
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
6262
'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
6363
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
6464
'sinkhorn_unbalanced', 'barycenter_unbalanced',
6565
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
66-
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
67-
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
68-
'factored_optimal_transport', 'solve',
66+
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
67+
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
68+
'factored_optimal_transport', 'solve',
6969
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
7070
'binary_search_circle', 'wasserstein_circle',
7171
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']

0 commit comments

Comments
 (0)