Merged
Changes from 29 commits
Commits
41 commits
0327b5c
add demd.py to ot, add plot_demd_*.py to examples, updated init.py in…
xzyu02 Apr 4, 2023
99f4ae3
Merge branch 'PythonOT:master' into demd
xzyu02 Apr 4, 2023
27878b7
update REAMDME.md with citation to iclr23 paper and example link
xzyu02 Apr 5, 2023
4a3a4f1
chaneg directory of examples, build successful
xzyu02 Apr 5, 2023
94e0f44
fix small latex bug
xzyu02 Apr 5, 2023
2957510
update all.rst, examples and demd have passed pep8 and pyflake
xzyu02 Apr 5, 2023
708b756
add more detailed comments for examples
xzyu02 Apr 5, 2023
7c813a7
TODO: test module for demd, wrong demd index after build
xzyu02 Apr 5, 2023
707152c
add test module
xzyu02 Apr 8, 2023
02bd955
Merge branch 'PythonOT:master' into demd
xzyu02 Apr 8, 2023
81ab727
add contributors
xzyu02 Apr 8, 2023
706d6a5
pass pyflake checks, pass pep8
xzyu02 Apr 8, 2023
4e6f693
added the PR to the RELEASES.md file
xzyu02 Apr 8, 2023
74c87dc
merge from master
xzyu02 Apr 20, 2023
6730631
Merge branch 'master' into demd
rflamary Apr 24, 2023
08bb919
temporal changes with logs
xzyu02 May 4, 2023
29d16f4
init changes
xzyu02 May 7, 2023
4226eee
merge examples, demd -> lp.dmmot
xzyu02 May 7, 2023
3c7ab34
bug fix in plot_dmmot, some commenting/documenting edits
ronakrm May 17, 2023
7452379
dmmot example cleanup, some comments/plotting edits
ronakrm May 17, 2023
9c360bb
add dist_monge method
xzyu02 Jun 1, 2023
21f16c5
merge from incoming
xzyu02 Jun 1, 2023
697036d
all dmmot methods takes (n, d) shape A as input (follows POT style)
xzyu02 Jun 1, 2023
70326a6
passed pep8 and pyflake checks
xzyu02 Jun 1, 2023
b4b4609
merge from master
xzyu02 Jun 1, 2023
be09209
Merge branch 'master' into demd
rflamary Jun 12, 2023
8d16b0f
Merge branch 'master' into demd
xzyu02 Jun 12, 2023
6de193c
resolve test fail issue
xzyu02 Jun 12, 2023
e98c7ee
fix pep8 error
xzyu02 Jun 13, 2023
7339e8a
resolve issues from last review, pyflake and pep8 checked
xzyu02 Jul 5, 2023
fd444b7
add lr decay
xzyu02 Jul 7, 2023
bd2d2ec
Merge branch 'master' into demd
rflamary Jul 10, 2023
f531b9e
add more examples, ground cost options, test for uniqueness
xzyu02 Jul 26, 2023
c1ccd46
Merge branch 'demd' of github.com:x12hengyu/POT into demd
xzyu02 Jul 26, 2023
99d2e86
Merge branch 'master' into demd
xzyu02 Jul 26, 2023
b3cb896
remove additional experiment setting, not needed in this PR
xzyu02 Jul 28, 2023
2d22fc9
fixed line 14 1 blank line
xzyu02 Jul 29, 2023
018313b
Merge branch 'master' into demd
rflamary Aug 2, 2023
a7bde66
fix gradient computation link
xzyu02 Aug 2, 2023
b370202
Merge branch 'demd' of github.com:x12hengyu/POT into demd
xzyu02 Aug 2, 2023
24a69c0
Update ot/lp/dmmot.py
rflamary Aug 3, 2023
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
@@ -42,6 +42,8 @@ The contributors to this library are:
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein)
* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)

## Acknowledgments

6 changes: 6 additions & 0 deletions README.md
@@ -43,6 +43,7 @@ POT provides the following generic OT solvers (links to examples):
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [52].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.

POT provides the following Machine Learning related solvers:
@@ -312,3 +313,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).

[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.

[52] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR).

[53] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019.

2 changes: 2 additions & 0 deletions RELEASES.md
@@ -16,6 +16,8 @@
- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455)
- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)

- Added feature Efficient Discrete Multi Marginal Optimal Transport Regularization + examples (PR #454)

#### Closed issues
- Fix circleci-redirector action and codecov (PR #460)
- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457)
150 changes: 150 additions & 0 deletions examples/others/plot_d-mmot.py
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
r"""
===============================================================================
Computing d-dimensional Barycenters via d-MMOT
===============================================================================

When the cost is discretized (Monge), the d-MMOT solver can more quickly
compute and minimize the distance between many distributions without the need
for intermediate barycenter computations. This example compares the time to
identify, and the quality of, solutions for the d-MMOT problem using a
primal/dual algorithm and classical LP barycenter approaches.
"""

# Author: Ronak Mehta
#         Xizheng Yu
#
# License: MIT License

# %%
# Generating 2 distributions
# -----
import numpy as np
import matplotlib.pyplot as pl
import ot

np.random.seed(0)

n = 100
d = 2
# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m=mean, s=std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
A = np.vstack((a1, a2)).T
x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.legend()

# %%
# Minimize the distances among distributions, identify the Barycenter
# -----
# The two methods minimize different objectives, so their objective values
# cannot be compared directly.

print('LP Iterations:')
ot.tic()
alpha = 1 # /d # 0<=alpha<=1
weights = np.array(d * [alpha])
lp_bary, lp_log = ot.lp.barycenter(
    A, M, weights, solver='interior-point', verbose=False, log=True)
print('Time\t: ', ot.toc(''))
print('Obj\t: ', lp_log['fun'])

print('')
print('Discrete MMOT Algorithm:')
ot.tic()
# dmmot_obj, log = ot.lp.discrete_mmot(A.T, n, d)
barys, log = ot.lp.discrete_mmot_converge(
    A, niters=3000, lr=0.000002, log=True)
dmmot_obj = log['primal objective']
Collaborator review comment:

Both the objective value and the norm of the gradient increase at the end, which is very surprising, since this is supposed to be gradient descent, no?
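
A minimal sketch of the kind of check this comment points at, run on a toy quadratic rather than on the d-MMOT objective. The helpers `gradient_descent` and `is_monotone_decreasing` are hypothetical names introduced here for illustration and are not part of POT. The sketch only shows that a recorded objective history can be tested for monotone decrease, and that with a fixed step size the history stops decreasing once the step is too large. It does not diagnose the run in this example.

```python
import numpy as np


def gradient_descent(f, grad, x0, lr, niters):
    """Plain fixed-step gradient descent that records f at every iterate."""
    x = np.asarray(x0, dtype=float)
    history = [f(x)]
    for _ in range(niters):
        x = x - lr * grad(x)
        history.append(f(x))
    return x, np.array(history)


def is_monotone_decreasing(history, tol=1e-12):
    """True if the objective never increases by more than tol between steps."""
    return bool(np.all(np.diff(history) <= tol))


def f(x):
    # Toy objective (an assumption for illustration): f(x) = 0.5 * ||x||^2
    return 0.5 * np.sum(x ** 2)


def grad(x):
    # Gradient of the toy objective
    return x


x0 = [1.0, -2.0]
_, hist_small = gradient_descent(f, grad, x0, lr=0.5, niters=20)
_, hist_large = gradient_descent(f, grad, x0, lr=2.5, niters=20)

print('lr=0.5 monotone:', is_monotone_decreasing(hist_small))
print('lr=2.5 monotone:', is_monotone_decreasing(hist_large))
```

If the solver exposes its per-iteration objective values, the same `is_monotone_decreasing` check could be applied to that history. Whether `discrete_mmot_converge` logs such a history is not shown in this diff.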

print('Time\t: ', ot.toc(''))
print('Obj\t: ', dmmot_obj)


# %%
# Compare Barycenters in both methods
# ---------
pl.figure(1, figsize=(6.4, 3))
for i in range(len(barys)):
Collaborator review comment:

You should compare it to the L2 (`np.mean`) barycenter, because your barycenter looks very similar.
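
A minimal sketch of the comparison suggested in this comment, assuming the same two Gaussian histograms as in the example. To keep it self-contained, `gauss_hist` is a hypothetical stand-in for `ot.datasets.make_1D_gauss`, and `l1_gap` is an illustrative helper. Neither is part of this PR.

```python
import numpy as np


def gauss_hist(n, m, s):
    """Stand-in for ot.datasets.make_1D_gauss: a normalized Gaussian histogram."""
    x = np.arange(n, dtype=np.float64)
    h = np.exp(-(x - m) ** 2 / (2 * s ** 2))
    return h / h.sum()


def l1_gap(p, q):
    """L1 distance between two histograms defined on the same bins."""
    return float(np.abs(np.asarray(p) - np.asarray(q)).sum())


n = 100
a1 = gauss_hist(n, m=20, s=5)
a2 = gauss_hist(n, m=60, s=8)
A = np.vstack((a1, a2)).T  # shape (n, d), one distribution per column

# L2 (Euclidean) barycenter with uniform weights: the bin-wise mean.
l2_bary = A.mean(axis=1)

print('L2 barycenter sums to', l2_bary.sum())
```

In the example, `l2_bary` could then be drawn next to the other curves with `pl.plot(x, l2_bary, 'm--', label='L2 mean')`, and `l1_gap(barys[0], l2_bary)` would quantify how close the d-MMOT result is to the plain mean.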

    if i == 0:
        pl.plot(x, barys[i], 'g-*', label='Discrete MMOT')
Collaborator review comment:

It would be nice to see visually whether you converged, by plotting all the individual distributions (seems like you did because your "barycenter" ). Maybe you could call it "Monge MMOT minimization" instead of discrete MMOT?
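
Alongside the visual check suggested in this comment, convergence could also be checked numerically. The sketch below assumes the optimized distributions are available as an array with one distribution per row, as the `barys[i]` indexing in the example suggests. `max_l1_spread` is a hypothetical helper, not a POT function.

```python
import numpy as np


def max_l1_spread(dists):
    """Largest L1 distance from any row of `dists` to the mean of all rows.

    A value near zero means all distributions have collapsed onto one curve.
    """
    dists = np.asarray(dists, dtype=float)
    center = dists.mean(axis=0)
    return float(np.abs(dists - center).sum(axis=1).max())


# Toy inputs for illustration only (not output of the d-MMOT solver).
identical = np.tile(np.array([0.25, 0.5, 0.25]), (4, 1))
different = np.array([[1.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0]])

print('identical rows:', max_l1_spread(identical))
print('different rows:', max_l1_spread(different))
```

A small `max_l1_spread(barys)` after the optimization would support picking any single row as the barycenter estimate, which is what the example does with `barys[0]`.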

    else:
        continue
        # pl.plot(x, barys[i], 'g-*')
pl.plot(x, lp_bary, 'k-', label='LP Barycenter')
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.title('Barycenters')
pl.legend()

# %%
# More than 2 distributions
# --------------------------------------------------
# Generate 7 pseudorandom gaussian distributions with 50 bins.
n = 50 # nb bins
d = 7
vecsize = n * d

data = []
for i in range(d):
    m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1)
    a = ot.datasets.make_1D_gauss(n, m=m, s=5)
    data.append(a)

x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
A = np.vstack(data).T

pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i])

pl.title('Distributions')
pl.legend()

# %%
# Minimizing Distances Among Many Distributions
# ---------------
# The two methods minimize different objectives, so their objective values
# cannot be compared directly.

# Perform gradient descent optimization using the d-MMOT method.
barys = ot.lp.discrete_mmot_converge(A, niters=9000, lr=0.00001)
Collaborator review comment:

Same here.


# After minimization, any distribution can be used as an estimate of the barycenter.
bary = barys[0]

# Compute 1D Wasserstein barycenter using the LP method
weights = ot.unif(d)
lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point',
                                     verbose=True, log=True)

# %%
# Compare Barycenters in both methods
# ---------
pl.figure(1, figsize=(6.4, 3))
pl.plot(x, bary, 'g-*', label='Discrete MMOT')
pl.plot(x, lp_bary, 'k-', label='LP Wasserstein')
pl.title('Barycenters')
pl.legend()

# %%
# Compare with original distributions
# ---------
pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i])
for i in range(len(barys)):
    if i == 0:
        pl.plot(x, barys[i], 'g-*', label='Discrete MMOT')
    else:
        continue
        # pl.plot(x, barys[i], 'g')
pl.plot(x, lp_bary, 'k-', label='LP Wasserstein')
# pl.plot(x, bary, 'g', label='Discrete MMOT')
pl.title('Barycenters')
pl.legend()
4 changes: 3 additions & 1 deletion ot/lp/__init__.py
@@ -17,6 +17,7 @@

from . import cvx
from .cvx import barycenter
from .dmmot import *

# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
@@ -30,7 +31,8 @@

 __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
            'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
-           'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
+           'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle',
+           'discrete_mmot', 'discrete_mmot_converge']


def check_number_threads(numThreads):