@@ -653,6 +653,18 @@ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         """
         raise NotImplementedError()
 
+    def dtype_device(self, a):
+        r"""
+        Returns the dtype and the device of the given tensor.
+        """
+        raise NotImplementedError()
+
+    def assert_same_dtype_device(self, a, b):
+        r"""
+        Checks whether or not the two given inputs have the same dtype as well as the same device
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -880,6 +892,16 @@ def copy(self, a):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        if hasattr(a, "dtype"):
+            return a.dtype, "cpu"
+        else:
+            return type(a), "cpu"
+
+    def assert_same_dtype_device(self, a, b):
+        # numpy has implicit type conversion so we automatically validate the test
+        pass
+
 
 class JaxBackend(Backend):
     """
@@ -899,17 +921,20 @@ def __init__(self):
         self.rng_ = jax.random.PRNGKey(42)
 
         for d in jax.devices():
-            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
-                                  jax.device_put(jnp.array(1, dtype=np.float64), d)]
+            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+                                  jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
 
     def to_numpy(self, a):
         return np.array(a)
 
+    def _change_device(self, a, type_as):
+        return jax.device_put(a, type_as.device_buffer.device())
+
     def from_numpy(self, a, type_as=None):
         if type_as is None:
             return jnp.array(a)
         else:
-            return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
+            return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
 
     def set_gradients(self, val, inputs, grads):
         from jax.flatten_util import ravel_pytree
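The new _change_device helper simply wraps jax.device_put so that arrays derived from a type_as reference end up on that reference's device, which is also why the creation ops below (zeros, ones, full, eye) now route through it. A rough sketch of the idea, assuming a JAX version where DeviceArray.device_buffer.device() is still available (the same accessor the patch uses):

import jax
import jax.numpy as jnp

ref = jax.device_put(jnp.ones(3), jax.devices()[0])  # reference array placed on an explicit device
new = jnp.zeros(3)                                    # may be created on the default device

# The same operation _change_device performs: move `new` onto ref's device.
new = jax.device_put(new, ref.device_buffer.device())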
@@ -928,13 +953,13 @@ def zeros(self, shape, type_as=None):
         if type_as is None:
             return jnp.zeros(shape)
         else:
-            return jnp.zeros(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)
 
     def ones(self, shape, type_as=None):
         if type_as is None:
             return jnp.ones(shape)
         else:
-            return jnp.ones(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)
 
     def arange(self, stop, start=0, step=1, type_as=None):
         return jnp.arange(start, stop, step)
@@ -943,13 +968,13 @@ def full(self, shape, fill_value, type_as=None):
         if type_as is None:
             return jnp.full(shape, fill_value)
         else:
-            return jnp.full(shape, fill_value, dtype=type_as.dtype)
+            return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)
 
     def eye(self, N, M=None, type_as=None):
         if type_as is None:
             return jnp.eye(N, M)
         else:
-            return jnp.eye(N, M, dtype=type_as.dtype)
+            return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)
 
     def sum(self, a, axis=None, keepdims=False):
         return jnp.sum(a, axis, keepdims=keepdims)
@@ -1127,6 +1152,16 @@ def copy(self, a):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        return a.dtype, a.device_buffer.device()
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
 
 class TorchBackend(Backend):
     """
@@ -1455,3 +1490,13 @@ def copy(self, a):
 
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+    def dtype_device(self, a):
+        return a.dtype, a.device
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
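An end-to-end sketch of how these methods might be exercised in a backend-agnostic test; the tensors and the TorchBackend import path are illustrative assumptions, not taken from the patch:

import torch
from ot.backend import TorchBackend  # assumed module path

nx = TorchBackend()
a = torch.zeros(3, 4, dtype=torch.float32)
b = nx.ones((3, 4), type_as=a)  # created with the same dtype and device as a

print(nx.dtype_device(a))          # (torch.float32, device(type='cpu'))
nx.assert_same_dtype_device(a, b)  # passes: same dtype, same device

# Raises AssertionError("Dtype discrepancy") because the dtypes differ.
nx.assert_same_dtype_device(a, torch.zeros(3, dtype=torch.float64))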