@@ -233,15 +233,15 @@ def _check_allowed_dtypes(
233233
234234 return other
235235
236- def _check_device (self , other : Array | bool | int | float | complex ) -> None :
237- """Check that other is on a device compatible with the current array"""
238- if isinstance ( other , ( bool , int , float , complex )):
239- return
240- elif isinstance (other , Array ):
236+ def _check_type_device (self , other : Array | bool | int | float | complex ) -> None :
237+ """Check that other is either a Python scalar or an array on a device
238+ compatible with the current array.
239+ """
240+ if isinstance (other , Array ):
241241 if self .device != other .device :
242242 raise ValueError (f"Arrays from two different devices ({ self .device } and { other .device } ) can not be combined." )
243- else :
244- raise TypeError (f"Expected Array | python scalar; got { type (other )} " )
243+ elif not isinstance ( other , bool | int | float | complex ) :
244+ raise TypeError (f"Expected Array or Python scalar; got { type (other )} " )
245245
246246 # Helper function to match the type promotion rules in the spec
247247 def _promote_scalar (self , scalar : bool | int | float | complex ) -> Array :
@@ -542,7 +542,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
542542 """
543543 Performs the operation __add__.
544544 """
545- self ._check_device (other )
545+ self ._check_type_device (other )
546546 other = self ._check_allowed_dtypes (other , "numeric" , "__add__" )
547547 if other is NotImplemented :
548548 return other
@@ -554,7 +554,7 @@ def __and__(self, other: Array | bool | int, /) -> Array:
554554 """
555555 Performs the operation __and__.
556556 """
557- self ._check_device (other )
557+ self ._check_type_device (other )
558558 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__and__" )
559559 if other is NotImplemented :
560560 return other
@@ -651,7 +651,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
651651 """
652652 Performs the operation __eq__.
653653 """
654- self ._check_device (other )
654+ self ._check_type_device (other )
655655 # Even though "all" dtypes are allowed, we still require them to be
656656 # promotable with each other.
657657 other = self ._check_allowed_dtypes (other , "all" , "__eq__" )
@@ -677,7 +677,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
677677 """
678678 Performs the operation __floordiv__.
679679 """
680- self ._check_device (other )
680+ self ._check_type_device (other )
681681 other = self ._check_allowed_dtypes (other , "real numeric" , "__floordiv__" )
682682 if other is NotImplemented :
683683 return other
@@ -689,7 +689,7 @@ def __ge__(self, other: Array | int | float, /) -> Array:
689689 """
690690 Performs the operation __ge__.
691691 """
692- self ._check_device (other )
692+ self ._check_type_device (other )
693693 other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" )
694694 if other is NotImplemented :
695695 return other
@@ -741,7 +741,7 @@ def __gt__(self, other: Array | int | float, /) -> Array:
741741 """
742742 Performs the operation __gt__.
743743 """
744- self ._check_device (other )
744+ self ._check_type_device (other )
745745 other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" )
746746 if other is NotImplemented :
747747 return other
@@ -796,7 +796,7 @@ def __le__(self, other: Array | int | float, /) -> Array:
796796 """
797797 Performs the operation __le__.
798798 """
799- self ._check_device (other )
799+ self ._check_type_device (other )
800800 other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" )
801801 if other is NotImplemented :
802802 return other
@@ -808,7 +808,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
808808 """
809809 Performs the operation __lshift__.
810810 """
811- self ._check_device (other )
811+ self ._check_type_device (other )
812812 other = self ._check_allowed_dtypes (other , "integer" , "__lshift__" )
813813 if other is NotImplemented :
814814 return other
@@ -820,7 +820,7 @@ def __lt__(self, other: Array | int | float, /) -> Array:
820820 """
821821 Performs the operation __lt__.
822822 """
823- self ._check_device (other )
823+ self ._check_type_device (other )
824824 other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" )
825825 if other is NotImplemented :
826826 return other
@@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array:
832832 """
833833 Performs the operation __matmul__.
834834 """
835- self ._check_device (other )
835+ self ._check_type_device (other )
836836 # matmul is not defined for scalars, but without this, we may get
837837 # the wrong error message from asarray.
838838 other = self ._check_allowed_dtypes (other , "numeric" , "__matmul__" )
@@ -845,7 +845,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
845845 """
846846 Performs the operation __mod__.
847847 """
848- self ._check_device (other )
848+ self ._check_type_device (other )
849849 other = self ._check_allowed_dtypes (other , "real numeric" , "__mod__" )
850850 if other is NotImplemented :
851851 return other
@@ -857,7 +857,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
857857 """
858858 Performs the operation __mul__.
859859 """
860- self ._check_device (other )
860+ self ._check_type_device (other )
861861 other = self ._check_allowed_dtypes (other , "numeric" , "__mul__" )
862862 if other is NotImplemented :
863863 return other
@@ -869,7 +869,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty
869869 """
870870 Performs the operation __ne__.
871871 """
872- self ._check_device (other )
872+ self ._check_type_device (other )
873873 other = self ._check_allowed_dtypes (other , "all" , "__ne__" )
874874 if other is NotImplemented :
875875 return other
@@ -890,7 +890,7 @@ def __or__(self, other: Array | bool | int, /) -> Array:
890890 """
891891 Performs the operation __or__.
892892 """
893- self ._check_device (other )
893+ self ._check_type_device (other )
894894 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__or__" )
895895 if other is NotImplemented :
896896 return other
@@ -913,7 +913,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array:
913913 """
914914 from ._elementwise_functions import pow # type: ignore[attr-defined]
915915
916- self ._check_device (other )
916+ self ._check_type_device (other )
917917 other = self ._check_allowed_dtypes (other , "numeric" , "__pow__" )
918918 if other is NotImplemented :
919919 return other
@@ -925,7 +925,7 @@ def __rshift__(self, other: Array | int, /) -> Array:
925925 """
926926 Performs the operation __rshift__.
927927 """
928- self ._check_device (other )
928+ self ._check_type_device (other )
929929 other = self ._check_allowed_dtypes (other , "integer" , "__rshift__" )
930930 if other is NotImplemented :
931931 return other
@@ -961,7 +961,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
961961 """
962962 Performs the operation __sub__.
963963 """
964- self ._check_device (other )
964+ self ._check_type_device (other )
965965 other = self ._check_allowed_dtypes (other , "numeric" , "__sub__" )
966966 if other is NotImplemented :
967967 return other
@@ -975,7 +975,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
975975 """
976976 Performs the operation __truediv__.
977977 """
978- self ._check_device (other )
978+ self ._check_type_device (other )
979979 other = self ._check_allowed_dtypes (other , "floating-point" , "__truediv__" )
980980 if other is NotImplemented :
981981 return other
@@ -987,7 +987,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
987987 """
988988 Performs the operation __xor__.
989989 """
990- self ._check_device (other )
990+ self ._check_type_device (other )
991991 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__xor__" )
992992 if other is NotImplemented :
993993 return other
@@ -999,7 +999,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
999999 """
10001000 Performs the operation __iadd__.
10011001 """
1002- self ._check_device (other )
1002+ self ._check_type_device (other )
10031003 other = self ._check_allowed_dtypes (other , "numeric" , "__iadd__" )
10041004 if other is NotImplemented :
10051005 return other
@@ -1010,7 +1010,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
10101010 """
10111011 Performs the operation __radd__.
10121012 """
1013- self ._check_device (other )
1013+ self ._check_type_device (other )
10141014 other = self ._check_allowed_dtypes (other , "numeric" , "__radd__" )
10151015 if other is NotImplemented :
10161016 return other
@@ -1022,7 +1022,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
10221022 """
10231023 Performs the operation __iand__.
10241024 """
1025- self ._check_device (other )
1025+ self ._check_type_device (other )
10261026 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__iand__" )
10271027 if other is NotImplemented :
10281028 return other
@@ -1033,7 +1033,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
10331033 """
10341034 Performs the operation __rand__.
10351035 """
1036- self ._check_device (other )
1036+ self ._check_type_device (other )
10371037 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rand__" )
10381038 if other is NotImplemented :
10391039 return other
@@ -1045,7 +1045,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
10451045 """
10461046 Performs the operation __ifloordiv__.
10471047 """
1048- self ._check_device (other )
1048+ self ._check_type_device (other )
10491049 other = self ._check_allowed_dtypes (other , "real numeric" , "__ifloordiv__" )
10501050 if other is NotImplemented :
10511051 return other
@@ -1056,7 +1056,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array:
10561056 """
10571057 Performs the operation __rfloordiv__.
10581058 """
1059- self ._check_device (other )
1059+ self ._check_type_device (other )
10601060 other = self ._check_allowed_dtypes (other , "real numeric" , "__rfloordiv__" )
10611061 if other is NotImplemented :
10621062 return other
@@ -1068,7 +1068,7 @@ def __ilshift__(self, other: Array | int, /) -> Array:
10681068 """
10691069 Performs the operation __ilshift__.
10701070 """
1071- self ._check_device (other )
1071+ self ._check_type_device (other )
10721072 other = self ._check_allowed_dtypes (other , "integer" , "__ilshift__" )
10731073 if other is NotImplemented :
10741074 return other
@@ -1079,7 +1079,7 @@ def __rlshift__(self, other: Array | int, /) -> Array:
10791079 """
10801080 Performs the operation __rlshift__.
10811081 """
1082- self ._check_device (other )
1082+ self ._check_type_device (other )
10831083 other = self ._check_allowed_dtypes (other , "integer" , "__rlshift__" )
10841084 if other is NotImplemented :
10851085 return other
@@ -1096,7 +1096,7 @@ def __imatmul__(self, other: Array, /) -> Array:
10961096 other = self ._check_allowed_dtypes (other , "numeric" , "__imatmul__" )
10971097 if other is NotImplemented :
10981098 return other
1099- self ._check_device (other )
1099+ self ._check_type_device (other )
11001100 res = self ._array .__imatmul__ (other ._array )
11011101 return self .__class__ ._new (res , device = self .device )
11021102
@@ -1109,7 +1109,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
11091109 other = self ._check_allowed_dtypes (other , "numeric" , "__rmatmul__" )
11101110 if other is NotImplemented :
11111111 return other
1112- self ._check_device (other )
1112+ self ._check_type_device (other )
11131113 res = self ._array .__rmatmul__ (other ._array )
11141114 return self .__class__ ._new (res , device = self .device )
11151115
@@ -1130,7 +1130,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
11301130 other = self ._check_allowed_dtypes (other , "real numeric" , "__rmod__" )
11311131 if other is NotImplemented :
11321132 return other
1133- self ._check_device (other )
1133+ self ._check_type_device (other )
11341134 self , other = self ._normalize_two_args (self , other )
11351135 res = self ._array .__rmod__ (other ._array )
11361136 return self .__class__ ._new (res , device = self .device )
@@ -1152,7 +1152,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
11521152 other = self ._check_allowed_dtypes (other , "numeric" , "__rmul__" )
11531153 if other is NotImplemented :
11541154 return other
1155- self ._check_device (other )
1155+ self ._check_type_device (other )
11561156 self , other = self ._normalize_two_args (self , other )
11571157 res = self ._array .__rmul__ (other ._array )
11581158 return self .__class__ ._new (res , device = self .device )
@@ -1171,7 +1171,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
11711171 """
11721172 Performs the operation __ror__.
11731173 """
1174- self ._check_device (other )
1174+ self ._check_type_device (other )
11751175 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__ror__" )
11761176 if other is NotImplemented :
11771177 return other
@@ -1219,7 +1219,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12191219 other = self ._check_allowed_dtypes (other , "integer" , "__rrshift__" )
12201220 if other is NotImplemented :
12211221 return other
1222- self ._check_device (other )
1222+ self ._check_type_device (other )
12231223 self , other = self ._normalize_two_args (self , other )
12241224 res = self ._array .__rrshift__ (other ._array )
12251225 return self .__class__ ._new (res , device = self .device )
@@ -1241,7 +1241,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
12411241 other = self ._check_allowed_dtypes (other , "numeric" , "__rsub__" )
12421242 if other is NotImplemented :
12431243 return other
1244- self ._check_device (other )
1244+ self ._check_type_device (other )
12451245 self , other = self ._normalize_two_args (self , other )
12461246 res = self ._array .__rsub__ (other ._array )
12471247 return self .__class__ ._new (res , device = self .device )
@@ -1263,7 +1263,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
12631263 other = self ._check_allowed_dtypes (other , "floating-point" , "__rtruediv__" )
12641264 if other is NotImplemented :
12651265 return other
1266- self ._check_device (other )
1266+ self ._check_type_device (other )
12671267 self , other = self ._normalize_two_args (self , other )
12681268 res = self ._array .__rtruediv__ (other ._array )
12691269 return self .__class__ ._new (res , device = self .device )
@@ -1285,7 +1285,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array:
12851285 other = self ._check_allowed_dtypes (other , "integer or boolean" , "__rxor__" )
12861286 if other is NotImplemented :
12871287 return other
1288- self ._check_device (other )
1288+ self ._check_type_device (other )
12891289 self , other = self ._normalize_two_args (self , other )
12901290 res = self ._array .__rxor__ (other ._array )
12911291 return self .__class__ ._new (res , device = self .device )
0 commit comments