Skip to content

Commit 8a4a5a6

Browse files
framunoz and Francisco Muñoz authored
[MRG] Debug convolutional methods that compute barycenters to work with different devices. (#533)
* [FEAT] Add the parameter 'type_as' to the backends * [TEST] add tests for the 'type_as' backends * [DEBUG] Debug dtype in pytorch * [FIX] Add type_as every time linspace is called * [TEST] Add test for the convolutional_barycenter2d algorithms (they are the only ones that use linspace) * [DEBUG] PEP 8 * [DOC] Add the new changes to RELEASES.md * [REFACTOR] Minor refactor that checks the GPU on the last line * [DEBUG] pep8 * [REFACTOR] Add a function to generalize the creation of random images * [REFACTOR] Maintain the same style as before * Update gitignore --------- Co-authored-by: Francisco Muñoz <[email protected]>
1 parent ffdd1cf commit 8a4a5a6

File tree

6 files changed

+299
-65
lines changed

6 files changed

+299
-65
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ celerybeat-schedule
9797
# virtualenv
9898
venv/
9999
ENV/
100+
.venv/
100101

101102
# Spyder project settings
102103
.spyderproject
@@ -120,4 +121,4 @@ debug
120121
.vscode
121122

122123
# pytest cache
123-
.pytest_cache
124+
.pytest_cache

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526)
77
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
88
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
9+
+ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533)
10+
+ The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices. (PR #533)
911

1012
#### Closed issues
1113
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)

ot/backend.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,15 @@
8686
#
8787
# License: MIT License
8888

89-
import numpy as np
9089
import os
91-
import scipy
92-
import scipy.linalg
93-
from scipy.sparse import issparse, coo_matrix, csr_matrix
94-
import scipy.special as special
9590
import time
9691
import warnings
9792

93+
import numpy as np
94+
import scipy
95+
import scipy.linalg
96+
import scipy.special as special
97+
from scipy.sparse import coo_matrix, csr_matrix, issparse
9898

9999
DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
100100
DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
@@ -650,7 +650,7 @@ def std(self, a, axis=None):
650650
"""
651651
raise NotImplementedError()
652652

653-
def linspace(self, start, stop, num):
653+
def linspace(self, start, stop, num, type_as=None):
654654
r"""
655655
Returns a specified number of evenly spaced values over a given interval.
656656
@@ -1208,8 +1208,11 @@ def median(self, a, axis=None):
12081208
def std(self, a, axis=None):
12091209
return np.std(a, axis=axis)
12101210

1211-
def linspace(self, start, stop, num):
1212-
return np.linspace(start, stop, num)
1211+
def linspace(self, start, stop, num, type_as=None):
1212+
if type_as is None:
1213+
return np.linspace(start, stop, num)
1214+
else:
1215+
return np.linspace(start, stop, num, dtype=type_as.dtype)
12131216

12141217
def meshgrid(self, a, b):
12151218
return np.meshgrid(a, b)
@@ -1579,8 +1582,11 @@ def median(self, a, axis=None):
15791582
def std(self, a, axis=None):
15801583
return jnp.std(a, axis=axis)
15811584

1582-
def linspace(self, start, stop, num):
1583-
return jnp.linspace(start, stop, num)
1585+
def linspace(self, start, stop, num, type_as=None):
1586+
if type_as is None:
1587+
return jnp.linspace(start, stop, num)
1588+
else:
1589+
return self._change_device(jnp.linspace(start, stop, num, dtype=type_as.dtype), type_as)
15841590

15851591
def meshgrid(self, a, b):
15861592
return jnp.meshgrid(a, b)
@@ -1986,6 +1992,7 @@ def concatenate(self, arrays, axis=0):
19861992

19871993
def zero_pad(self, a, pad_width, value=0):
19881994
from torch.nn.functional import pad
1995+
19891996
# pad_width is an array of ndim tuples indicating how many 0 before and after
19901997
# we need to add. We first need to make it compliant with torch syntax, that
19911998
# starts with the last dim, then second last, etc.
@@ -2006,6 +2013,7 @@ def mean(self, a, axis=None):
20062013

20072014
def median(self, a, axis=None):
20082015
from packaging import version
2016+
20092017
# Since version 1.11.0, interpolation is available
20102018
if version.parse(torch.__version__) >= version.parse("1.11.0"):
20112019
if axis is not None:
@@ -2026,8 +2034,11 @@ def std(self, a, axis=None):
20262034
else:
20272035
return torch.std(a, unbiased=False)
20282036

2029-
def linspace(self, start, stop, num):
2030-
return torch.linspace(start, stop, num, dtype=torch.float64)
2037+
def linspace(self, start, stop, num, type_as=None):
2038+
if type_as is None:
2039+
return torch.linspace(start, stop, num)
2040+
else:
2041+
return torch.linspace(start, stop, num, dtype=type_as.dtype, device=type_as.device)
20312042

20322043
def meshgrid(self, a, b):
20332044
try:
@@ -2427,8 +2438,12 @@ def median(self, a, axis=None):
24272438
def std(self, a, axis=None):
24282439
return cp.std(a, axis=axis)
24292440

2430-
def linspace(self, start, stop, num):
2431-
return cp.linspace(start, stop, num)
2441+
def linspace(self, start, stop, num, type_as=None):
2442+
if type_as is None:
2443+
return cp.linspace(start, stop, num)
2444+
else:
2445+
with cp.cuda.Device(type_as.device):
2446+
return cp.linspace(start, stop, num, dtype=type_as.dtype)
24322447

24332448
def meshgrid(self, a, b):
24342449
return cp.meshgrid(a, b)
@@ -2834,8 +2849,11 @@ def median(self, a, axis=None):
28342849
def std(self, a, axis=None):
28352850
return tnp.std(a, axis=axis)
28362851

2837-
def linspace(self, start, stop, num):
2838-
return tnp.linspace(start, stop, num)
2852+
def linspace(self, start, stop, num, type_as=None):
2853+
if type_as is None:
2854+
return tnp.linspace(start, stop, num)
2855+
else:
2856+
return tnp.linspace(start, stop, num, dtype=type_as.dtype)
28392857

28402858
def meshgrid(self, a, b):
28412859
return tnp.meshgrid(a, b)

ot/bregman.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import numpy as np
2121
from scipy.optimize import fmin_l_bfgs_b
2222

23-
from ot.utils import unif, dist, list_to_array
23+
from ot.utils import dist, list_to_array, unif
24+
2425
from .backend import get_backend
2526

2627

@@ -2217,11 +2218,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
22172218

22182219
# build the convolution operator
22192220
# this is equivalent to blurring on horizontal then vertical directions
2220-
t = nx.linspace(0, 1, A.shape[1])
2221+
t = nx.linspace(0, 1, A.shape[1], type_as=A)
22212222
[Y, X] = nx.meshgrid(t, t)
22222223
K1 = nx.exp(-(X - Y) ** 2 / reg)
22232224

2224-
t = nx.linspace(0, 1, A.shape[2])
2225+
t = nx.linspace(0, 1, A.shape[2], type_as=A)
22252226
[Y, X] = nx.meshgrid(t, t)
22262227
K2 = nx.exp(-(X - Y) ** 2 / reg)
22272228

@@ -2295,11 +2296,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
22952296
err = 1
22962297
# build the convolution operator
22972298
# this is equivalent to blurring on horizontal then vertical directions
2298-
t = nx.linspace(0, 1, width)
2299+
t = nx.linspace(0, 1, width, type_as=A)
22992300
[Y, X] = nx.meshgrid(t, t)
23002301
M1 = - (X - Y) ** 2 / reg
23012302

2302-
t = nx.linspace(0, 1, height)
2303+
t = nx.linspace(0, 1, height, type_as=A)
23032304
[Y, X] = nx.meshgrid(t, t)
23042305
M2 = - (X - Y) ** 2 / reg
23052306

@@ -2452,11 +2453,11 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
24522453

24532454
# build the convolution operator
24542455
# this is equivalent to blurring on horizontal then vertical directions
2455-
t = nx.linspace(0, 1, width)
2456+
t = nx.linspace(0, 1, width, type_as=A)
24562457
[Y, X] = nx.meshgrid(t, t)
24572458
K1 = nx.exp(-(X - Y) ** 2 / reg)
24582459

2459-
t = nx.linspace(0, 1, height)
2460+
t = nx.linspace(0, 1, height, type_as=A)
24602461
[Y, X] = nx.meshgrid(t, t)
24612462
K2 = nx.exp(-(X - Y) ** 2 / reg)
24622463

@@ -2532,11 +2533,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
25322533
err = 1
25332534
# build the convolution operator
25342535
# this is equivalent to blurring on horizontal then vertical directions
2535-
t = nx.linspace(0, 1, width)
2536+
t = nx.linspace(0, 1, width, type_as=A)
25362537
[Y, X] = nx.meshgrid(t, t)
25372538
M1 = - (X - Y) ** 2 / reg
25382539

2539-
t = nx.linspace(0, 1, height)
2540+
t = nx.linspace(0, 1, height, type_as=A)
25402541
[Y, X] = nx.meshgrid(t, t)
25412542
M2 = - (X - Y) ** 2 / reg
25422543

test/test_backend.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,13 @@
66
#
77
# License: MIT License
88

9-
import ot
10-
import ot.backend
11-
from ot.backend import torch, jax, tf
12-
13-
import pytest
14-
159
import numpy as np
10+
import pytest
1611
from numpy.testing import assert_array_almost_equal_nulp
1712

18-
from ot.backend import get_backend, get_backend_list, to_numpy
13+
import ot
14+
import ot.backend
15+
from ot.backend import get_backend, get_backend_list, jax, tf, to_numpy, torch
1916

2017

2118
def test_get_backend_list():
@@ -507,6 +504,7 @@ def test_func_backends(nx):
507504
lst_name.append('std')
508505

509506
A = nx.linspace(0, 1, 50)
507+
A = nx.linspace(0, 1, 50, type_as=Mb)
510508
lst_b.append(nx.to_numpy(A))
511509
lst_name.append('linspace')
512510

0 commit comments

Comments
 (0)