Skip to content

Commit 2778175

Browse files
committed
First draft : making pytest use gpu for torch testing
1 parent e1b67c6 commit 2778175

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

ot/backend.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,15 @@
4747
str_type_error = "All array should be from the same type/backend. Current types are : {}"
4848

4949

50-
def get_backend_list():
50+
def get_backend_list(test_GPU=False):
5151
"""Returns the list of available backends"""
5252
lst = [NumpyBackend(), ]
5353

5454
if torch:
5555
lst.append(TorchBackend())
56+
if test_GPU and torch.cuda.is_available():
57+
# TODO: auto activate test_gpu if a GPU is present on the machine
58+
lst.append(_TorchBackendGPU())
5659

5760
if jax:
5861
lst.append(JaxBackend())
@@ -1437,3 +1440,20 @@ def copy(self, a):
14371440

14381441
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
14391442
return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
1443+
1444+
1445+
class _TorchBackendGPU(TorchBackend):
1446+
r"""
1447+
This class allows to test the torch backend on GPUs. By wrapping the standard from_numpy method, this backend places the tensor on a GPU by default, which allow to make the test without significant changes in the code of the test itself.
1448+
"""
1449+
1450+
__name__ = TorchBackend().__name__ + ".gpu"
1451+
1452+
def __str__(self):
1453+
return super().__name__
1454+
1455+
def from_numpy(self, a, type_as=None):
1456+
tensor = super().from_numpy(a, type_as=type_as)
1457+
if type_as is None:
1458+
tensor = tensor.cuda()
1459+
return tensor

0 commit comments

Comments
 (0)