Skip to content

Commit eb7d814

Browse files
add tests
1 parent 54cdeab commit eb7d814

File tree

2 files changed

+125
-10
lines changed

2 files changed

+125
-10
lines changed

ot/gromov/_bregman.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -871,9 +871,9 @@ def entropic_fused_gromov_barycenters(
871871
else:
872872
Y = init_Y
873873

874-
T = [nx.outer(p, q) for q in ps]
874+
T = [nx.outer(p_, p) for p_ in ps]
875875

876-
Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]
876+
Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
877877

878878
cpt = 0
879879
err = 1
@@ -897,12 +897,12 @@ def entropic_fused_gromov_barycenters(
897897
if warmstartT:
898898
T = [entropic_fused_gromov_wasserstein(
899899
Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
900-
None, max_iter, 1e-4, verbose, log=False, **kwargs) for s in range(S)]
900+
None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
901901

902902
else:
903903
T = [entropic_fused_gromov_wasserstein(
904904
Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
905-
None, max_iter, 1e-4, verbose, log=False, **kwargs) for s in range(S)]
905+
None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
906906

907907
if loss_fun == 'square_loss':
908908
C = update_square_loss(p, lambdas, T, Cs)
@@ -911,8 +911,9 @@ def entropic_fused_gromov_barycenters(
911911
C = update_kl_loss(p, lambdas, T, Cs)
912912

913913
Ys_temp = [y.T for y in Ys]
914-
Y = update_feature_matrix(lambdas, Ys_temp, T, p).T
915-
Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]
914+
T_temp = [Ts.T for Ts in T]
915+
Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p)
916+
Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
916917

917918
if cpt % 10 == 0:
918919
# we can speed up the process by checking for the error only all
@@ -932,7 +933,7 @@ def entropic_fused_gromov_barycenters(
932933
print('{:5d}|{:8e}|'.format(cpt, err_feature))
933934

934935
cpt += 1
935-
936+
print('Y type:', type(Y))
936937
if log:
937938
log_['T'] = T # from target to Ys
938939
log_['p'] = p

test/test_gromov.py

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,61 @@ def test_entropic_fgw_dtype_device(nx):
694694
nx.assert_same_dtype_device(C1b, fgw_valb)
695695

696696

697+
def test_entropic_fgw_barycenter(nx):
698+
ns = 5
699+
nt = 10
700+
701+
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
702+
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
703+
704+
ys = np.random.randn(Xs.shape[0], 2)
705+
yt = np.random.randn(Xt.shape[0], 2)
706+
707+
C1 = ot.dist(Xs)
708+
C2 = ot.dist(Xt)
709+
p1 = ot.unif(ns)
710+
p2 = ot.unif(nt)
711+
n_samples = 2
712+
p = ot.unif(n_samples)
713+
714+
ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
715+
716+
X, C, log = ot.gromov.entropic_fused_gromov_barycenters(
717+
n_samples, [ys, yt], [C1, C2], [p1, p2], p, [.5, .5], 'square_loss', 0.1,
718+
max_iter=50, tol=1e-3, verbose=True, warmstartT=True, random_state=42,
719+
solver='PPA', numItermax=1, log=True
720+
)
721+
Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters(
722+
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', 0.1,
723+
max_iter=50, tol=1e-3, verbose=False, warmstartT=True, random_state=42,
724+
solver='PPA', numItermax=1, log=False)
725+
Xb, Cb = nx.to_numpy(Xb, Cb)
726+
727+
np.testing.assert_allclose(C, Cb, atol=1e-06)
728+
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
729+
np.testing.assert_allclose(X, Xb, atol=1e-06)
730+
np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
731+
732+
# test with 'kl_loss' and log=True
733+
X, C, log = ot.gromov.entropic_fused_gromov_barycenters(
734+
n_samples, [ys, yt], [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', 0.1,
735+
max_iter=50, tol=1e-3, verbose=False, warmstartT=False, random_state=42,
736+
solver='PPA', numItermax=1, log=True
737+
)
738+
Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters(
739+
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', 0.1,
740+
max_iter=50, tol=1e-3, verbose=False, warmstartT=False, random_state=42,
741+
solver='PPA', numItermax=1, log=True)
742+
Xb, Cb = nx.to_numpy(Xb, Cb)
743+
744+
np.testing.assert_allclose(C, Cb, atol=1e-06)
745+
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
746+
np.testing.assert_allclose(X, Xb, atol=1e-06)
747+
np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
748+
np.testing.assert_array_almost_equal(log['err_feature'], nx.to_numpy(*logb['err_feature']))
749+
np.testing.assert_array_almost_equal(log['err_structure'], nx.to_numpy(*logb['err_structure']))
750+
751+
697752
def test_pointwise_gromov(nx):
698753
n_samples = 5 # nb samples
699754

@@ -1173,6 +1228,9 @@ def test_fgw_barycenter(nx):
11731228

11741229
C1 = ot.dist(Xs)
11751230
C2 = ot.dist(Xt)
1231+
C1 /= C1.max()
1232+
C2 /= C2.max()
1233+
11761234
p1, p2 = ot.unif(ns), ot.unif(nt)
11771235
n_samples = 3
11781236
p = ot.unif(n_samples)
@@ -1186,6 +1244,7 @@ def test_fgw_barycenter(nx):
11861244

11871245
xalea = np.random.randn(n_samples, 2)
11881246
init_C = ot.dist(xalea, xalea)
1247+
init_C /= init_C.max()
11891248
init_Cb = nx.from_numpy(init_C)
11901249

11911250
Xb, Cb = ot.gromov.fgw_barycenters(
@@ -1206,9 +1265,18 @@ def test_fgw_barycenter(nx):
12061265
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
12071266
warmstartT=True, log=True, random_state=98765
12081267
)
1209-
Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
1210-
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
1211-
np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
1268+
X, C = nx.to_numpy(Xb), nx.to_numpy(Cb)
1269+
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
1270+
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
1271+
1272+
# add test with 'kl_loss'
1273+
X, C = ot.gromov.fgw_barycenters(
1274+
n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5,
1275+
fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss',
1276+
max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, random_state=12345
1277+
)
1278+
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
1279+
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
12121280

12131281

12141282
def test_gromov_wasserstein_linear_unmixing(nx):
@@ -2277,3 +2345,49 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx):
22772345

22782346
nx.assert_same_dtype_device(C1b, Gb)
22792347
nx.assert_same_dtype_device(C1b, fgw_valb)
2348+
2349+
2350+
def test_not_implemented_solver():
2351+
# test sinkhorn
2352+
n_samples = 5 # nb samples
2353+
mu_s = np.array([0, 0])
2354+
cov_s = np.array([[1, 0], [0, 1]])
2355+
2356+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
2357+
xt = xs[::-1].copy()
2358+
2359+
ys = np.random.randn(xs.shape[0], 2)
2360+
yt = ys[::-1].copy()
2361+
2362+
p = ot.unif(n_samples)
2363+
q = ot.unif(n_samples)
2364+
2365+
C1 = ot.dist(xs, xs)
2366+
C2 = ot.dist(xt, xt)
2367+
2368+
C1 /= C1.max()
2369+
C2 /= C2.max()
2370+
M = ot.dist(ys, yt)
2371+
2372+
solver = 'not_implemented'
2373+
# entropic gw and fgw
2374+
with pytest.raises(ValueError):
2375+
ot.gromov.entropic_gromov_wasserstein(
2376+
C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver)
2377+
with pytest.raises(ValueError):
2378+
ot.gromov.entropic_fused_gromov_wasserstein(
2379+
M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver)
2380+
2381+
# exact and entropic srgw and srfgw loss functions
2382+
loss_fun = 'kl_loss'
2383+
with pytest.raises(NotImplementedError):
2384+
ot.gromov.semirelaxed_gromov_wasserstein(
2385+
C1, C2, p, loss_fun, armijo=False)
2386+
with pytest.raises(NotImplementedError):
2387+
ot.gromov.entropic_semirelaxed_gromov_wasserstein(
2388+
C1, C2, p, loss_fun, epsilon=0.1)
2389+
with pytest.raises(NotImplementedError):
2390+
ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun)
2391+
with pytest.raises(NotImplementedError):
2392+
ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(
2393+
M, C1, C2, p, loss_fun, epsilon=0.1)

0 commit comments

Comments
 (0)