@@ -694,6 +694,61 @@ def test_entropic_fgw_dtype_device(nx):
694
694
nx .assert_same_dtype_device (C1b , fgw_valb )
695
695
696
696
697
+ def test_entropic_fgw_barycenter (nx ):
698
+ ns = 5
699
+ nt = 10
700
+
701
+ Xs , ys = ot .datasets .make_data_classif ('3gauss' , ns , random_state = 42 )
702
+ Xt , yt = ot .datasets .make_data_classif ('3gauss2' , nt , random_state = 42 )
703
+
704
+ ys = np .random .randn (Xs .shape [0 ], 2 )
705
+ yt = np .random .randn (Xt .shape [0 ], 2 )
706
+
707
+ C1 = ot .dist (Xs )
708
+ C2 = ot .dist (Xt )
709
+ p1 = ot .unif (ns )
710
+ p2 = ot .unif (nt )
711
+ n_samples = 2
712
+ p = ot .unif (n_samples )
713
+
714
+ ysb , ytb , C1b , C2b , p1b , p2b , pb = nx .from_numpy (ys , yt , C1 , C2 , p1 , p2 , p )
715
+
716
+ X , C , log = ot .gromov .entropic_fused_gromov_barycenters (
717
+ n_samples , [ys , yt ], [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ], 'square_loss' , 0.1 ,
718
+ max_iter = 50 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42 ,
719
+ solver = 'PPA' , numItermax = 1 , log = True
720
+ )
721
+ Xb , Cb = ot .gromov .entropic_fused_gromov_barycenters (
722
+ n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ], 'square_loss' , 0.1 ,
723
+ max_iter = 50 , tol = 1e-3 , verbose = False , warmstartT = True , random_state = 42 ,
724
+ solver = 'PPA' , numItermax = 1 , log = False )
725
+ Xb , Cb = nx .to_numpy (Xb , Cb )
726
+
727
+ np .testing .assert_allclose (C , Cb , atol = 1e-06 )
728
+ np .testing .assert_allclose (Cb .shape , (n_samples , n_samples ))
729
+ np .testing .assert_allclose (X , Xb , atol = 1e-06 )
730
+ np .testing .assert_allclose (Xb .shape , (n_samples , ys .shape [1 ]))
731
+
732
+ # test with 'kl_loss' and log=True
733
+ X , C , log = ot .gromov .entropic_fused_gromov_barycenters (
734
+ n_samples , [ys , yt ], [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ], 'kl_loss' , 0.1 ,
735
+ max_iter = 50 , tol = 1e-3 , verbose = False , warmstartT = False , random_state = 42 ,
736
+ solver = 'PPA' , numItermax = 1 , log = True
737
+ )
738
+ Xb , Cb , logb = ot .gromov .entropic_fused_gromov_barycenters (
739
+ n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ], 'kl_loss' , 0.1 ,
740
+ max_iter = 50 , tol = 1e-3 , verbose = False , warmstartT = False , random_state = 42 ,
741
+ solver = 'PPA' , numItermax = 1 , log = True )
742
+ Xb , Cb = nx .to_numpy (Xb , Cb )
743
+
744
+ np .testing .assert_allclose (C , Cb , atol = 1e-06 )
745
+ np .testing .assert_allclose (Cb .shape , (n_samples , n_samples ))
746
+ np .testing .assert_allclose (X , Xb , atol = 1e-06 )
747
+ np .testing .assert_allclose (Xb .shape , (n_samples , ys .shape [1 ]))
748
+ np .testing .assert_array_almost_equal (log ['err_feature' ], nx .to_numpy (* logb ['err_feature' ]))
749
+ np .testing .assert_array_almost_equal (log ['err_structure' ], nx .to_numpy (* logb ['err_structure' ]))
750
+
751
+
697
752
def test_pointwise_gromov (nx ):
698
753
n_samples = 5 # nb samples
699
754
@@ -1173,6 +1228,9 @@ def test_fgw_barycenter(nx):
1173
1228
1174
1229
C1 = ot .dist (Xs )
1175
1230
C2 = ot .dist (Xt )
1231
+ C1 /= C1 .max ()
1232
+ C2 /= C2 .max ()
1233
+
1176
1234
p1 , p2 = ot .unif (ns ), ot .unif (nt )
1177
1235
n_samples = 3
1178
1236
p = ot .unif (n_samples )
@@ -1186,6 +1244,7 @@ def test_fgw_barycenter(nx):
1186
1244
1187
1245
xalea = np .random .randn (n_samples , 2 )
1188
1246
init_C = ot .dist (xalea , xalea )
1247
+ init_C /= init_C .max ()
1189
1248
init_Cb = nx .from_numpy (init_C )
1190
1249
1191
1250
Xb , Cb = ot .gromov .fgw_barycenters (
@@ -1206,9 +1265,18 @@ def test_fgw_barycenter(nx):
1206
1265
p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 ,
1207
1266
warmstartT = True , log = True , random_state = 98765
1208
1267
)
1209
- Xb , Cb = nx .to_numpy (Xb ), nx .to_numpy (Cb )
1210
- np .testing .assert_allclose (Cb .shape , (n_samples , n_samples ))
1211
- np .testing .assert_allclose (Xb .shape , (n_samples , ys .shape [1 ]))
1268
+ X , C = nx .to_numpy (Xb ), nx .to_numpy (Cb )
1269
+ np .testing .assert_allclose (C .shape , (n_samples , n_samples ))
1270
+ np .testing .assert_allclose (X .shape , (n_samples , ys .shape [1 ]))
1271
+
1272
+ # add test with 'kl_loss'
1273
+ X , C = ot .gromov .fgw_barycenters (
1274
+ n_samples , [ys , yt ], [C1 , C2 ], [p1 , p2 ], [.5 , .5 ], 0.5 ,
1275
+ fixed_structure = False , fixed_features = False , p = p , loss_fun = 'kl_loss' ,
1276
+ max_iter = 100 , tol = 1e-3 , init_C = C , init_X = X , warmstartT = True , random_state = 12345
1277
+ )
1278
+ np .testing .assert_allclose (C .shape , (n_samples , n_samples ))
1279
+ np .testing .assert_allclose (X .shape , (n_samples , ys .shape [1 ]))
1212
1280
1213
1281
1214
1282
def test_gromov_wasserstein_linear_unmixing (nx ):
@@ -2277,3 +2345,49 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx):
2277
2345
2278
2346
nx .assert_same_dtype_device (C1b , Gb )
2279
2347
nx .assert_same_dtype_device (C1b , fgw_valb )
2348
+
2349
+
2350
+ def test_not_implemented_solver ():
2351
+ # test sinkhorn
2352
+ n_samples = 5 # nb samples
2353
+ mu_s = np .array ([0 , 0 ])
2354
+ cov_s = np .array ([[1 , 0 ], [0 , 1 ]])
2355
+
2356
+ xs = ot .datasets .make_2D_samples_gauss (n_samples , mu_s , cov_s , random_state = 42 )
2357
+ xt = xs [::- 1 ].copy ()
2358
+
2359
+ ys = np .random .randn (xs .shape [0 ], 2 )
2360
+ yt = ys [::- 1 ].copy ()
2361
+
2362
+ p = ot .unif (n_samples )
2363
+ q = ot .unif (n_samples )
2364
+
2365
+ C1 = ot .dist (xs , xs )
2366
+ C2 = ot .dist (xt , xt )
2367
+
2368
+ C1 /= C1 .max ()
2369
+ C2 /= C2 .max ()
2370
+ M = ot .dist (ys , yt )
2371
+
2372
+ solver = 'not_implemented'
2373
+ # entropic gw and fgw
2374
+ with pytest .raises (ValueError ):
2375
+ ot .gromov .entropic_gromov_wasserstein (
2376
+ C1 , C2 , p , q , 'square_loss' , epsilon = 1e-1 , solver = solver )
2377
+ with pytest .raises (ValueError ):
2378
+ ot .gromov .entropic_fused_gromov_wasserstein (
2379
+ M , C1 , C2 , p , q , 'square_loss' , epsilon = 1e-1 , solver = solver )
2380
+
2381
+ # exact and entropic srgw and srfgw loss functions
2382
+ loss_fun = 'kl_loss'
2383
+ with pytest .raises (NotImplementedError ):
2384
+ ot .gromov .semirelaxed_gromov_wasserstein (
2385
+ C1 , C2 , p , loss_fun , armijo = False )
2386
+ with pytest .raises (NotImplementedError ):
2387
+ ot .gromov .entropic_semirelaxed_gromov_wasserstein (
2388
+ C1 , C2 , p , loss_fun , epsilon = 0.1 )
2389
+ with pytest .raises (NotImplementedError ):
2390
+ ot .gromov .semirelaxed_fused_gromov_wasserstein2 (M , C1 , C2 , p , loss_fun )
2391
+ with pytest .raises (NotImplementedError ):
2392
+ ot .gromov .entropic_semirelaxed_fused_gromov_wasserstein (
2393
+ M , C1 , C2 , p , loss_fun , epsilon = 0.1 )
0 commit comments