Skip to content

Commit fcce951

Browse files
authored
Merge branch 'master' into patch-1
2 parents 328e03a + f395e58 commit fcce951

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

RELEASES.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#### Closed issues
66
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
7-
7+
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
88

99
## 0.9.2
1010
*December 2023*
@@ -671,4 +671,4 @@ It provides the following solvers:
671671
* Optimal transport for domain adaptation with group lasso regularization
672672
* Conditional gradient and Generalized conditional gradient for regularized OT.
673673

674-
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
674+
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

ot/gromov/_semirelaxed.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
114114
else:
115115
q = nx.sum(G0, 0)
116116
# Check first marginal of G0
117-
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
117+
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)
118118

119119
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
120120

@@ -363,8 +363,8 @@ def semirelaxed_fused_gromov_wasserstein(
363363
G0 = nx.outer(p, q)
364364
else:
365365
q = nx.sum(G0, 0)
366-
# Check marginals of G0
367-
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
366+
# Check first marginal of G0
367+
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)
368368

369369
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
370370

@@ -703,7 +703,7 @@ def entropic_semirelaxed_gromov_wasserstein(
703703
else:
704704
q = nx.sum(G0, 0)
705705
# Check first marginal of G0
706-
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
706+
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)
707707

708708
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
709709

@@ -951,7 +951,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
951951
else:
952952
q = nx.sum(G0, 0)
953953
# Check first marginal of G0
954-
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
954+
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)
955955

956956
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
957957

0 commit comments

Comments
 (0)