[MRG] Efficient Discrete Multi Marginal Optimal Transport #454
# -*- coding: utf-8 -*-
r"""
===============================================================================
Computing d-dimensional Barycenters via d-MMOT
===============================================================================

When the cost is discretized (Monge), the d-MMOT solver can more quickly
compute and minimize the distance between many distributions without the need
for intermediate barycenter computations. This example compares the time to
identify, and the quality of, solutions for the d-MMOT problem using a
primal/dual algorithm and classical LP barycenter approaches.
"""

# Author: Ronak Mehta <[email protected]>
#         Xizheng Yu <[email protected]>
#
# License: MIT License

# %%
# Generating 2 distributions
# --------------------------
import numpy as np
import matplotlib.pyplot as pl
import ot

np.random.seed(0)

n = 100
d = 2

# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m=mean, s=std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
A = np.vstack((a1, a2)).T
x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.legend()

# %%
# Minimize the distances among distributions, identify the Barycenter
# -------------------------------------------------------------------
# The objective being minimized is different for the two methods, so the
# objective values cannot be compared directly.

print('LP Iterations:')
ot.tic()
alpha = 1  # uniform weights; 0 <= alpha <= 1
weights = np.array(d * [alpha])
lp_bary, lp_log = ot.lp.barycenter(
    A, M, weights, solver='interior-point', verbose=False, log=True)
print('Time\t: ', ot.toc(''))
print('Obj\t: ', lp_log['fun'])

print('')
print('Discrete MMOT Algorithm:')
ot.tic()
barys, log = ot.lp.discrete_mmot_converge(
    A, niters=3000, lr=0.000002, log=True)
dmmot_obj = log['primal objective']
print('Time\t: ', ot.toc(''))
print('Obj\t: ', dmmot_obj)

# %%
# Compare Barycenters in both methods
# -----------------------------------
pl.figure(1, figsize=(6.4, 3))
# plot the first returned d-MMOT marginal as the barycenter estimate
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, lp_bary, 'k-', label='LP Barycenter')
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.title('Barycenters')
pl.legend()

Reviewer comment: you should compare it to the l2 (np.mean) barycenter,
because your barycenter looks very similar.

Reviewer comment: it would be nice to see visually whether you converged, by
plotting all the individual distributions (it seems you did, judging by your
"barycenter"). Maybe you could call it "Monge MMOT minimization" instead of
"discrete MMOT"?
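Following the review suggestion above, here is a minimal self-contained sketch of the l2 (np.mean) barycenter the reviewer asks to compare against. It rebuilds the two Gaussian histograms in plain NumPy instead of calling `ot.datasets.make_1D_gauss`, so the snippet runs standalone; the helper name `gauss_hist` is illustrative, not part of POT.

```python
import numpy as np


def gauss_hist(n, m, s):
    # normalized Gaussian histogram on n unit-spaced bins (illustrative helper)
    x = np.arange(n, dtype=np.float64)
    h = np.exp(-((x - m) ** 2) / (2 * s ** 2))
    return h / h.sum()


n = 100
a1 = gauss_hist(n, m=20, s=5)
a2 = gauss_hist(n, m=60, s=8)
A = np.vstack((a1, a2)).T  # shape (n, d), one histogram per column

# the l2 barycenter of histograms is simply their bin-wise mean
l2_bary = A.mean(axis=1)
```

In the example above, plotting `l2_bary` next to the d-MMOT curve (e.g. `pl.plot(x, l2_bary, 'm--', label='L2 (mean) barycenter')`) would show at a glance whether the d-MMOT solution merely averages the inputs (and stays bimodal) or actually interpolates them in Wasserstein space.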

# %%
# More than 2 distributions
# -------------------------
# Generate 7 pseudorandom Gaussian distributions with 50 bins.
n = 50  # nb bins
d = 7
vecsize = n * d

data = []
for i in range(d):
    m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1)
    a = ot.datasets.make_1D_gauss(n, m=m, s=5)
    data.append(a)

x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
A = np.vstack(data).T

pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i], label='Distribution %d' % (i + 1))

pl.title('Distributions')
pl.legend()

# %%
# Minimizing Distances Among Many Distributions
# ---------------------------------------------
# The objective being minimized is different for the two methods, so the
# objective values cannot be compared directly.

# Perform gradient descent optimization using the d-MMOT method
barys = ot.lp.discrete_mmot_converge(A, niters=9000, lr=0.00001)

Reviewer comment: same here.

# After minimization, any of the returned marginals can be used as an
# estimate of the barycenter.
bary = barys[0]

# Compute the 1D Wasserstein barycenter using the LP method
weights = ot.unif(d)
lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point',
                                     verbose=True, log=True)

# %%
# Compare Barycenters in both methods
# -----------------------------------
pl.figure(1, figsize=(6.4, 3))
pl.plot(x, bary, 'g-*', label='Discrete MMOT')
pl.plot(x, lp_bary, 'k-', label='LP Wasserstein')
pl.title('Barycenters')
pl.legend()
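Since the two solvers minimize different objectives, their reported objective values cannot be compared; one way to put both barycenter candidates on a common scale is the closed-form 1D Wasserstein-1 distance, which on a shared grid is the L1 distance between CDFs. A plain-NumPy sketch, assuming unit-spaced bins as in this example (the function name `w1_unit_grid` is illustrative):

```python
import numpy as np


def w1_unit_grid(a, b):
    # 1D Wasserstein-1 distance between two histograms supported on the
    # same unit-spaced grid: the L1 distance between their CDFs
    return float(np.abs(np.cumsum(a) - np.cumsum(b)).sum())


# sanity check: moving one unit of mass by two bins costs 2
p = np.array([1.0, 0.0, 0.0])
q = np.array([0.0, 0.0, 1.0])
cost = w1_unit_grid(p, q)
```

A candidate barycenter could then be scored by something like `sum(w1_unit_grid(bary, A[:, j]) for j in range(d))`, giving a directly comparable number for both the LP and d-MMOT solutions.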

# %%
# Compare with original distributions
# -----------------------------------
pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i])
# plot the first returned d-MMOT marginal as the barycenter estimate
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, lp_bary, 'k-', label='LP Wasserstein')
pl.title('Barycenters')
pl.legend()

Reviewer comment: both the objective value and the norm of the gradient
increase at the end, which is very surprising since it is supposed to be a
gradient descent, no?
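The non-monotone behaviour flagged in the review can occur with plain gradient descent when a fixed step size is too large for the local curvature: the iterates overshoot and the objective can grow. A self-contained toy sketch of the effect on a quadratic (plain NumPy, not the POT solver; all names here are illustrative):

```python
import numpy as np


def gd_objectives(lr, steps=20, x0=10.0, curvature=5.0):
    # gradient descent on f(x) = 0.5 * curvature * x**2, whose
    # gradient is f'(x) = curvature * x
    x = x0
    objs = []
    for _ in range(steps):
        x -= lr * curvature * x
        objs.append(0.5 * curvature * x ** 2)
    return objs


# stable regime: lr < 2 / curvature (= 0.4 here), objective decreases monotonically
stable = gd_objectives(lr=0.1)
# overshooting regime: lr > 2 / curvature, iterates oscillate and the objective grows
unstable = gd_objectives(lr=0.5)
```

If the d-MMOT log shows the same late-stage growth, shrinking `lr`, or decaying it over iterations, would be the usual first thing to try.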