Skip to content

Commit c809707

Browse files
add examples + tests + warning in entropic solvers + releases
1 parent eb7d814 commit c809707

File tree

7 files changed

+418
-47
lines changed

7 files changed

+418
-47
lines changed

RELEASES.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
# Releases
22

3+
## 0.9.1dev
4+
*Sine April 2023*
5+
6+
7+
#### New features
8+
9+
- Added Proximal Point algorithm to solve GW problems via a new parameter `solver="PPA"` in `ot.gromov.entropic_gromov_wasserstein` + examples (PR #455)
10+
- Added features `warmstart` and `kwargs` in `ot.gromov.entropic_gromov_wasserstein` to respectively perform warmstart on dual potentials and pass parameters to `ot.sinkhorn` (PR #455)
11+
- Added sinkhorn projection based solvers for FGW `ot.gromov.entropic_fused_gromov_wasserstein` and entropic FGW barycenters + examples (PR #455)
12+
- Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455)
13+
- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455)
14+
15+
#### Closed issues
16+
17+
318
## 0.9.0
419

520
This new release contains so many new features and bug fixes since 0.8.2 that we

examples/gromov/plot_entropic_semirelaxed_fgw.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@
113113
# ---------------------------------------------
114114
#
115115
# We color nodes of the graph on the right - then project its node colors
116-
# based on the optimal transport plan from the entropic srGW matching
116+
# based on the optimal transport plan from the entropic srGW matching.
117+
# We adjust the intensity of links across domains proportionaly to the mass
118+
# sent, adding a minimal intensity of 0.1 if mass sent is not zero.
117119

118120

119121
def draw_graph(G, C, nodes_color_part, Gweights=None,
@@ -187,11 +189,12 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
187189
pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
188190
node_size=node_size, shiftx=shiftx, seed=seed_G2)
189191
for k1, v1 in pos1.items():
192+
max_Tk1 = np.max(T[k1, :])
190193
for k2, v2 in pos2.items():
191194
if (T[k1, k2] > 0):
192195
pl.plot([pos1[k1][0], pos2[k2][0]],
193196
[pos1[k1][1], pos2[k2][1]],
194-
'-', lw=0.5, alpha=0.3,
197+
'-', lw=0.6, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.),
195198
color=nodes_color_part1[k1])
196199
return pos1, pos2
197200

examples/gromov/plot_fgw.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
Plot Fused-gromov-Wasserstein
55
==============================
66
7-
This example illustrates the computation of FGW for 1D measures [18].
7+
This example first illustrates the computation of FGW for 1D measures estimated
8+
using a Conditional Gradient solver [24].
89
9-
[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
10+
[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
1011
and Courty Nicolas
1112
"Optimal Transport for structured data with application on graphs"
1213
International Conference on Machine Learning (ICML). 2019.
13-
1414
"""
1515

1616
# Author: Titouan Vayer <[email protected]>
@@ -24,11 +24,13 @@
2424
import ot
2525
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
2626

27+
2728
##############################################################################
2829
# Generate data
2930
# -------------
3031

31-
#%% parameters
32+
# parameters
33+
3234
# We create two 1D random measures
3335
n = 20 # number of points in the first distribution
3436
n2 = 30 # number of points in the second distribution
@@ -53,10 +55,9 @@
5355
# Plot data
5456
# ---------
5557

56-
#%% plot the distributions
58+
# plot the distributions
5759

58-
pl.close(10)
59-
pl.figure(10, (7, 7))
60+
pl.figure(1, (7, 7))
6061

6162
pl.subplot(2, 1, 1)
6263

@@ -78,7 +79,7 @@
7879
# Create structure matrices and across-feature distance matrix
7980
# ------------------------------------------------------------
8081

81-
#%% Structure matrices and across-features distance matrix
82+
# Structure matrices and across-features distance matrix
8283
C1 = ot.dist(xs)
8384
C2 = ot.dist(xt)
8485
M = ot.dist(ys, yt)
@@ -90,10 +91,9 @@
9091
# Plot matrices
9192
# -------------
9293

93-
#%%
9494
cmap = 'Reds'
95-
pl.close(10)
96-
pl.figure(10, (5, 5))
95+
96+
pl.figure(2, (5, 5))
9797
fs = 15
9898
l_x = [0, 5, 10, 15]
9999
l_y = [0, 5, 10, 15, 20, 25]
@@ -113,7 +113,6 @@
113113
pl.imshow(C2, cmap=cmap, interpolation='nearest')
114114
pl.title("$C_2$", fontsize=fs)
115115
pl.ylabel("$l$", fontsize=fs)
116-
#pl.ylabel("$l$",fontsize=fs)
117116
pl.xticks(())
118117
pl.yticks(l_y)
119118
ax2.set_aspect('auto')
@@ -133,28 +132,27 @@
133132
# Compute FGW/GW
134133
# --------------
135134

136-
#%% Computing FGW and GW
135+
# Computing FGW and GW
137136
alpha = 1e-3
138137

139138
ot.tic()
140139
Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
141140
ot.toc()
142141

143-
#%reload_ext WGW
142+
# reload_ext WGW
144143
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
145144

146145
##############################################################################
147146
# Visualize transport matrices
148147
# ----------------------------
149148

150-
#%% visu OT matrix
149+
# visu OT matrix
151150
cmap = 'Blues'
152151
fs = 15
153-
pl.figure(2, (13, 5))
152+
pl.figure(3, (13, 5))
154153
pl.clf()
155154
pl.subplot(1, 3, 1)
156155
pl.imshow(Got, cmap=cmap, interpolation='nearest')
157-
#pl.xlabel("$y$",fontsize=fs)
158156
pl.ylabel("$i$", fontsize=fs)
159157
pl.xticks(())
160158

0 commit comments

Comments
 (0)