Skip to content

Commit 91c67fb

Browse files
authored
[MRG] Avoid changing precision in the backend (#572)
* Avoid changing precision in the backend * Update RELEASES.md
1 parent a56e1b2 commit 91c67fb

File tree

3 files changed

+3
-2
lines changed

3 files changed

+3
-2
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
2323
- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559)
2424
- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566)
25+
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)
2526

2627
## 0.9.1
2728
*August 2023*

ot/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1941,7 +1941,7 @@ def power(self, a, exponents):
19411941
return torch.pow(a, exponents)
19421942

19431943
def norm(self, a, axis=None, keepdims=False):
1944-
return torch.linalg.norm(a.double(), dim=axis, keepdims=keepdims)
1944+
return torch.linalg.norm(a, dim=axis, keepdims=keepdims)
19451945

19461946
def any(self, a):
19471947
return torch.any(a)

test/test_gaussian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_empirical_bures_wasserstein_mapping_numerical_error_warning():
8181

8282

8383
def test_bures_wasserstein_distance(nx):
84-
ms, mt = np.array([0]), np.array([10])
84+
ms, mt = np.array([0]).astype(np.float32), np.array([10]).astype(np.float32)
8585
Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)
8686
msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
8787
Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)

0 commit comments

Comments
 (0)