|
4 | 4 | Plot Fused-gromov-Wasserstein
|
5 | 5 | ==============================
|
6 | 6 |
|
7 |
| -This example illustrates the computation of FGW for 1D measures [18]. |
| 7 | +This example first illustrates the computation of FGW for 1D measures estimated |
| 8 | +using a Conditional Gradient solver [24]. |
8 | 9 |
|
9 |
| -[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain |
| 10 | +[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain |
10 | 11 | and Courty Nicolas
|
11 | 12 | "Optimal Transport for structured data with application on graphs"
|
12 | 13 | International Conference on Machine Learning (ICML). 2019.
|
13 |
| -
|
14 | 14 | """
|
15 | 15 |
|
16 | 16 | # Author: Titouan Vayer <[email protected]>
|
|
24 | 24 | import ot
|
25 | 25 | from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
|
26 | 26 |
|
| 27 | + |
27 | 28 | ##############################################################################
|
28 | 29 | # Generate data
|
29 | 30 | # -------------
|
30 | 31 |
|
31 |
| -#%% parameters |
| 32 | +# parameters |
| 33 | + |
32 | 34 | # We create two 1D random measures
|
33 | 35 | n = 20 # number of points in the first distribution
|
34 | 36 | n2 = 30 # number of points in the second distribution
|
|
53 | 55 | # Plot data
|
54 | 56 | # ---------
|
55 | 57 |
|
56 |
| -#%% plot the distributions |
| 58 | +# plot the distributions |
57 | 59 |
|
58 |
| -pl.close(10) |
59 |
| -pl.figure(10, (7, 7)) |
| 60 | +pl.figure(1, (7, 7)) |
60 | 61 |
|
61 | 62 | pl.subplot(2, 1, 1)
|
62 | 63 |
|
|
78 | 79 | # Create structure matrices and across-feature distance matrix
|
79 | 80 | # ------------------------------------------------------------
|
80 | 81 |
|
81 |
| -#%% Structure matrices and across-features distance matrix |
| 82 | +# Structure matrices and across-features distance matrix |
82 | 83 | C1 = ot.dist(xs)
|
83 | 84 | C2 = ot.dist(xt)
|
84 | 85 | M = ot.dist(ys, yt)
|
|
90 | 91 | # Plot matrices
|
91 | 92 | # -------------
|
92 | 93 |
|
93 |
| -#%% |
94 | 94 | cmap = 'Reds'
|
95 |
| -pl.close(10) |
96 |
| -pl.figure(10, (5, 5)) |
| 95 | + |
| 96 | +pl.figure(2, (5, 5)) |
97 | 97 | fs = 15
|
98 | 98 | l_x = [0, 5, 10, 15]
|
99 | 99 | l_y = [0, 5, 10, 15, 20, 25]
|
|
113 | 113 | pl.imshow(C2, cmap=cmap, interpolation='nearest')
|
114 | 114 | pl.title("$C_2$", fontsize=fs)
|
115 | 115 | pl.ylabel("$l$", fontsize=fs)
|
116 |
| -#pl.ylabel("$l$",fontsize=fs) |
117 | 116 | pl.xticks(())
|
118 | 117 | pl.yticks(l_y)
|
119 | 118 | ax2.set_aspect('auto')
|
|
133 | 132 | # Compute FGW/GW
|
134 | 133 | # --------------
|
135 | 134 |
|
136 |
| -#%% Computing FGW and GW |
| 135 | +# Computing FGW and GW |
137 | 136 | alpha = 1e-3
|
138 | 137 |
|
139 | 138 | ot.tic()
|
140 | 139 | Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
|
141 | 140 | ot.toc()
|
142 | 141 |
|
143 |
| -#%reload_ext WGW |
| 142 | +# reload_ext WGW |
144 | 143 | Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
|
145 | 144 |
|
146 | 145 | ##############################################################################
|
147 | 146 | # Visualize transport matrices
|
148 | 147 | # ----------------------------
|
149 | 148 |
|
150 |
| -#%% visu OT matrix |
| 149 | +# visu OT matrix |
151 | 150 | cmap = 'Blues'
|
152 | 151 | fs = 15
|
153 |
| -pl.figure(2, (13, 5)) |
| 152 | +pl.figure(3, (13, 5)) |
154 | 153 | pl.clf()
|
155 | 154 | pl.subplot(1, 3, 1)
|
156 | 155 | pl.imshow(Got, cmap=cmap, interpolation='nearest')
|
157 |
| -#pl.xlabel("$y$",fontsize=fs) |
158 | 156 | pl.ylabel("$i$", fontsize=fs)
|
159 | 157 | pl.xticks(())
|
160 | 158 |
|
|
0 commit comments