8686#
8787# License: MIT License
8888
89- import numpy as np
9089import os
91- import scipy
92- import scipy .linalg
93- from scipy .sparse import issparse , coo_matrix , csr_matrix
94- import scipy .special as special
9590import time
9691import warnings
9792
93+ import numpy as np
94+ import scipy
95+ import scipy .linalg
96+ import scipy .special as special
97+ from scipy .sparse import coo_matrix , csr_matrix , issparse
9898
9999DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
100100DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
@@ -650,7 +650,7 @@ def std(self, a, axis=None):
650650 """
651651 raise NotImplementedError ()
652652
653- def linspace (self , start , stop , num ):
653+ def linspace (self , start , stop , num , type_as = None ):
654654 r"""
655655 Returns a specified number of evenly spaced values over a given interval.
656656
@@ -1208,8 +1208,11 @@ def median(self, a, axis=None):
12081208 def std (self , a , axis = None ):
12091209 return np .std (a , axis = axis )
12101210
1211- def linspace (self , start , stop , num ):
1212- return np .linspace (start , stop , num )
1211+ def linspace (self , start , stop , num , type_as = None ):
1212+ if type_as is None :
1213+ return np .linspace (start , stop , num )
1214+ else :
1215+ return np .linspace (start , stop , num , dtype = type_as .dtype )
12131216
12141217 def meshgrid (self , a , b ):
12151218 return np .meshgrid (a , b )
@@ -1579,8 +1582,11 @@ def median(self, a, axis=None):
15791582 def std (self , a , axis = None ):
15801583 return jnp .std (a , axis = axis )
15811584
1582- def linspace (self , start , stop , num ):
1583- return jnp .linspace (start , stop , num )
1585+ def linspace (self , start , stop , num , type_as = None ):
1586+ if type_as is None :
1587+ return jnp .linspace (start , stop , num )
1588+ else :
1589+ return self ._change_device (jnp .linspace (start , stop , num , dtype = type_as .dtype ), type_as )
15841590
15851591 def meshgrid (self , a , b ):
15861592 return jnp .meshgrid (a , b )
@@ -1986,6 +1992,7 @@ def concatenate(self, arrays, axis=0):
19861992
19871993 def zero_pad (self , a , pad_width , value = 0 ):
19881994 from torch .nn .functional import pad
1995+
19891996 # pad_width is an array of ndim tuples indicating how many 0 before and after
19901997 # we need to add. We first need to make it compliant with torch syntax, that
19911998 # starts with the last dim, then second last, etc.
@@ -2006,6 +2013,7 @@ def mean(self, a, axis=None):
20062013
20072014 def median (self , a , axis = None ):
20082015 from packaging import version
2016+
20092017 # Since version 1.11.0, interpolation is available
20102018 if version .parse (torch .__version__ ) >= version .parse ("1.11.0" ):
20112019 if axis is not None :
@@ -2026,8 +2034,11 @@ def std(self, a, axis=None):
20262034 else :
20272035 return torch .std (a , unbiased = False )
20282036
2029- def linspace (self , start , stop , num ):
2030- return torch .linspace (start , stop , num , dtype = torch .float64 )
2037+ def linspace (self , start , stop , num , type_as = None ):
2038+ if type_as is None :
2039+ return torch .linspace (start , stop , num )
2040+ else :
2041+ return torch .linspace (start , stop , num , dtype = type_as .dtype , device = type_as .device )
20312042
20322043 def meshgrid (self , a , b ):
20332044 try :
@@ -2427,8 +2438,12 @@ def median(self, a, axis=None):
24272438 def std (self , a , axis = None ):
24282439 return cp .std (a , axis = axis )
24292440
2430- def linspace (self , start , stop , num ):
2431- return cp .linspace (start , stop , num )
2441+ def linspace (self , start , stop , num , type_as = None ):
2442+ if type_as is None :
2443+ return cp .linspace (start , stop , num )
2444+ else :
2445+ with cp .cuda .Device (type_as .device ):
2446+ return cp .linspace (start , stop , num , dtype = type_as .dtype )
24322447
24332448 def meshgrid (self , a , b ):
24342449 return cp .meshgrid (a , b )
@@ -2834,8 +2849,11 @@ def median(self, a, axis=None):
28342849 def std (self , a , axis = None ):
28352850 return tnp .std (a , axis = axis )
28362851
2837- def linspace (self , start , stop , num ):
2838- return tnp .linspace (start , stop , num )
2852+ def linspace (self , start , stop , num , type_as = None ):
2853+ if type_as is None :
2854+ return tnp .linspace (start , stop , num )
2855+ else :
2856+ return tnp .linspace (start , stop , num , dtype = type_as .dtype )
28392857
28402858 def meshgrid (self , a , b ):
28412859 return tnp .meshgrid (a , b )
0 commit comments