@@ -199,7 +199,7 @@ def __array__(
199199 # NumPy behavior
200200
201201 def _check_allowed_dtypes (
202- self , other : Array | complex , dtype_category : str , op : str
202+ self , other : Array | bool | int | float | complex , dtype_category : str , op : str
203203 ) -> Array :
204204 """
205205 Helper function for operators to only allow specific input dtypes
@@ -241,7 +241,7 @@ def _check_allowed_dtypes(
241241
242242 return other
243243
244- def _check_device (self , other : Array | complex ) -> None :
244+ def _check_device (self , other : Array | bool | int | float | complex ) -> None :
245245 """Check that other is on a device compatible with the current array"""
246246 if isinstance (other , (bool , int , float , complex )):
247247 return
@@ -252,7 +252,7 @@ def _check_device(self, other: Array | complex) -> None:
252252 raise TypeError (f"Expected Array | python scalar; got { type (other )} " )
253253
254254 # Helper function to match the type promotion rules in the spec
255- def _promote_scalar (self , scalar : complex ) -> Array :
255+ def _promote_scalar (self , scalar : bool | int | float | complex ) -> Array :
256256 """
257257 Returns a promoted version of a Python scalar appropriate for use with
258258 operations on self.
@@ -546,7 +546,7 @@ def __abs__(self) -> Array:
546546 res = self ._array .__abs__ ()
547547 return self .__class__ ._new (res , device = self .device )
548548
549- def __add__ (self , other : Array | complex , / ) -> Array :
549+ def __add__ (self , other : Array | int | float | complex , / ) -> Array :
550550 """
551551 Performs the operation __add__.
552552 """
@@ -558,7 +558,7 @@ def __add__(self, other: Array | complex, /) -> Array:
558558 res = self ._array .__add__ (other ._array )
559559 return self .__class__ ._new (res , device = self .device )
560560
561- def __and__ (self , other : Array | int , / ) -> Array :
561+ def __and__ (self , other : Array | bool | int , / ) -> Array :
562562 """
563563 Performs the operation __and__.
564564 """
@@ -655,7 +655,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]:
655655 # Note: device support is required for this
656656 return self ._array .__dlpack_device__ ()
657657
658- def __eq__ (self , other : Array | complex , / ) -> Array : # type: ignore[override]
658+ def __eq__ (self , other : Array | bool | int | float | complex , / ) -> Array : # type: ignore[override]
659659 """
660660 Performs the operation __eq__.
661661 """
@@ -681,7 +681,7 @@ def __float__(self) -> float:
681681 res = self ._array .__float__ ()
682682 return res
683683
684- def __floordiv__ (self , other : Array | float , / ) -> Array :
684+ def __floordiv__ (self , other : Array | int | float , / ) -> Array :
685685 """
686686 Performs the operation __floordiv__.
687687 """
@@ -693,7 +693,7 @@ def __floordiv__(self, other: Array | float, /) -> Array:
693693 res = self ._array .__floordiv__ (other ._array )
694694 return self .__class__ ._new (res , device = self .device )
695695
696- def __ge__ (self , other : Array | float , / ) -> Array :
696+ def __ge__ (self , other : Array | int | float , / ) -> Array :
697697 """
698698 Performs the operation __ge__.
699699 """
@@ -728,7 +728,7 @@ def __getitem__(
728728 res = self ._array .__getitem__ (np_key )
729729 return self ._new (res , device = self .device )
730730
731- def __gt__ (self , other : Array | float , / ) -> Array :
731+ def __gt__ (self , other : Array | int | float , / ) -> Array :
732732 """
733733 Performs the operation __gt__.
734734 """
@@ -783,7 +783,7 @@ def __iter__(self) -> Iterator[Array]:
783783 # implemented, which implies iteration on 1-D arrays.
784784 return (Array ._new (i , device = self .device ) for i in self ._array )
785785
786- def __le__ (self , other : Array | float , / ) -> Array :
786+ def __le__ (self , other : Array | int | float , / ) -> Array :
787787 """
788788 Performs the operation __le__.
789789 """
@@ -807,7 +807,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
807807 res = self ._array .__lshift__ (other ._array )
808808 return self .__class__ ._new (res , device = self .device )
809809
810- def __lt__ (self , other : Array | float , / ) -> Array :
810+ def __lt__ (self , other : Array | int | float , / ) -> Array :
811811 """
812812 Performs the operation __lt__.
813813 """
@@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array:
832832 res = self ._array .__matmul__ (other ._array )
833833 return self .__class__ ._new (res , device = self .device )
834834
835- def __mod__ (self , other : Array | float , / ) -> Array :
835+ def __mod__ (self , other : Array | int | float , / ) -> Array :
836836 """
837837 Performs the operation __mod__.
838838 """
@@ -844,7 +844,7 @@ def __mod__(self, other: Array | float, /) -> Array:
844844 res = self ._array .__mod__ (other ._array )
845845 return self .__class__ ._new (res , device = self .device )
846846
847- def __mul__ (self , other : Array | complex , / ) -> Array :
847+ def __mul__ (self , other : Array | int | float | complex , / ) -> Array :
848848 """
849849 Performs the operation __mul__.
850850 """
@@ -856,7 +856,7 @@ def __mul__(self, other: Array | complex, /) -> Array:
856856 res = self ._array .__mul__ (other ._array )
857857 return self .__class__ ._new (res , device = self .device )
858858
859- def __ne__ (self , other : Array | complex , / ) -> Array : # type: ignore[override]
859+ def __ne__ (self , other : Array | bool | int | float | complex , / ) -> Array : # type: ignore[override]
860860 """
861861 Performs the operation __ne__.
862862 """
@@ -877,7 +877,7 @@ def __neg__(self) -> Array:
877877 res = self ._array .__neg__ ()
878878 return self .__class__ ._new (res , device = self .device )
879879
880- def __or__ (self , other : Array | int , / ) -> Array :
880+ def __or__ (self , other : Array | bool | int , / ) -> Array :
881881 """
882882 Performs the operation __or__.
883883 """
@@ -898,7 +898,7 @@ def __pos__(self) -> Array:
898898 res = self ._array .__pos__ ()
899899 return self .__class__ ._new (res , device = self .device )
900900
901- def __pow__ (self , other : Array | complex , / ) -> Array :
901+ def __pow__ (self , other : Array | int | float | complex , / ) -> Array :
902902 """
903903 Performs the operation __pow__.
904904 """
@@ -942,7 +942,7 @@ def __setitem__(
942942 np_key = key ._array if isinstance (key , Array ) else key
943943 self ._array .__setitem__ (np_key , asarray (value )._array )
944944
945- def __sub__ (self , other : Array | complex , / ) -> Array :
945+ def __sub__ (self , other : Array | int | float | complex , / ) -> Array :
946946 """
947947 Performs the operation __sub__.
948948 """
@@ -956,7 +956,7 @@ def __sub__(self, other: Array | complex, /) -> Array:
956956
957957 # PEP 484 requires int to be a subtype of float, but __truediv__ should
958958 # not accept int.
959- def __truediv__ (self , other : Array | complex , / ) -> Array :
959+ def __truediv__ (self , other : Array | int | float | complex , / ) -> Array :
960960 """
961961 Performs the operation __truediv__.
962962 """
@@ -968,7 +968,7 @@ def __truediv__(self, other: Array | complex, /) -> Array:
968968 res = self ._array .__truediv__ (other ._array )
969969 return self .__class__ ._new (res , device = self .device )
970970
971- def __xor__ (self , other : Array | int , / ) -> Array :
971+ def __xor__ (self , other : Array | bool | int , / ) -> Array :
972972 """
973973 Performs the operation __xor__.
974974 """
@@ -980,7 +980,7 @@ def __xor__(self, other: Array | int, /) -> Array:
980980 res = self ._array .__xor__ (other ._array )
981981 return self .__class__ ._new (res , device = self .device )
982982
983- def __iadd__ (self , other : Array | complex , / ) -> Array :
983+ def __iadd__ (self , other : Array | int | float | complex , / ) -> Array :
984984 """
985985 Performs the operation __iadd__.
986986 """
@@ -991,7 +991,7 @@ def __iadd__(self, other: Array | complex, /) -> Array:
991991 self ._array .__iadd__ (other ._array )
992992 return self
993993
994- def __radd__ (self , other : Array | complex , / ) -> Array :
994+ def __radd__ (self , other : Array | int | float | complex , / ) -> Array :
995995 """
996996 Performs the operation __radd__.
997997 """
@@ -1003,7 +1003,7 @@ def __radd__(self, other: Array | complex, /) -> Array:
10031003 res = self ._array .__radd__ (other ._array )
10041004 return self .__class__ ._new (res , device = self .device )
10051005
1006- def __iand__ (self , other : Array | int , / ) -> Array :
1006+ def __iand__ (self , other : Array | bool | int , / ) -> Array :
10071007 """
10081008 Performs the operation __iand__.
10091009 """
@@ -1014,7 +1014,7 @@ def __iand__(self, other: Array | int, /) -> Array:
10141014 self ._array .__iand__ (other ._array )
10151015 return self
10161016
1017- def __rand__ (self , other : Array | int , / ) -> Array :
1017+ def __rand__ (self , other : Array | bool | int , / ) -> Array :
10181018 """
10191019 Performs the operation __rand__.
10201020 """
@@ -1026,7 +1026,7 @@ def __rand__(self, other: Array | int, /) -> Array:
10261026 res = self ._array .__rand__ (other ._array )
10271027 return self .__class__ ._new (res , device = self .device )
10281028
1029- def __ifloordiv__ (self , other : Array | float , / ) -> Array :
1029+ def __ifloordiv__ (self , other : Array | int | float , / ) -> Array :
10301030 """
10311031 Performs the operation __ifloordiv__.
10321032 """
@@ -1037,7 +1037,7 @@ def __ifloordiv__(self, other: Array | float, /) -> Array:
10371037 self ._array .__ifloordiv__ (other ._array )
10381038 return self
10391039
1040- def __rfloordiv__ (self , other : Array | float , / ) -> Array :
1040+ def __rfloordiv__ (self , other : Array | int | float , / ) -> Array :
10411041 """
10421042 Performs the operation __rfloordiv__.
10431043 """
@@ -1098,7 +1098,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
10981098 res = self ._array .__rmatmul__ (other ._array )
10991099 return self .__class__ ._new (res , device = self .device )
11001100
1101- def __imod__ (self , other : Array | float , / ) -> Array :
1101+ def __imod__ (self , other : Array | int | float , / ) -> Array :
11021102 """
11031103 Performs the operation __imod__.
11041104 """
@@ -1108,7 +1108,7 @@ def __imod__(self, other: Array | float, /) -> Array:
11081108 self ._array .__imod__ (other ._array )
11091109 return self
11101110
1111- def __rmod__ (self , other : Array | float , / ) -> Array :
1111+ def __rmod__ (self , other : Array | int | float , / ) -> Array :
11121112 """
11131113 Performs the operation __rmod__.
11141114 """
@@ -1120,7 +1120,7 @@ def __rmod__(self, other: Array | float, /) -> Array:
11201120 res = self ._array .__rmod__ (other ._array )
11211121 return self .__class__ ._new (res , device = self .device )
11221122
1123- def __imul__ (self , other : Array | complex , / ) -> Array :
1123+ def __imul__ (self , other : Array | int | float | complex , / ) -> Array :
11241124 """
11251125 Performs the operation __imul__.
11261126 """
@@ -1130,7 +1130,7 @@ def __imul__(self, other: Array | complex, /) -> Array:
11301130 self ._array .__imul__ (other ._array )
11311131 return self
11321132
1133- def __rmul__ (self , other : Array | complex , / ) -> Array :
1133+ def __rmul__ (self , other : Array | int | float | complex , / ) -> Array :
11341134 """
11351135 Performs the operation __rmul__.
11361136 """
@@ -1142,7 +1142,7 @@ def __rmul__(self, other: Array | complex, /) -> Array:
11421142 res = self ._array .__rmul__ (other ._array )
11431143 return self .__class__ ._new (res , device = self .device )
11441144
1145- def __ior__ (self , other : Array | int , / ) -> Array :
1145+ def __ior__ (self , other : Array | bool | int , / ) -> Array :
11461146 """
11471147 Performs the operation __ior__.
11481148 """
@@ -1152,7 +1152,7 @@ def __ior__(self, other: Array | int, /) -> Array:
11521152 self ._array .__ior__ (other ._array )
11531153 return self
11541154
1155- def __ror__ (self , other : Array | int , / ) -> Array :
1155+ def __ror__ (self , other : Array | bool | int , / ) -> Array :
11561156 """
11571157 Performs the operation __ror__.
11581158 """
@@ -1164,7 +1164,7 @@ def __ror__(self, other: Array | int, /) -> Array:
11641164 res = self ._array .__ror__ (other ._array )
11651165 return self .__class__ ._new (res , device = self .device )
11661166
1167- def __ipow__ (self , other : Array | complex , / ) -> Array :
1167+ def __ipow__ (self , other : Array | int | float | complex , / ) -> Array :
11681168 """
11691169 Performs the operation __ipow__.
11701170 """
@@ -1174,7 +1174,7 @@ def __ipow__(self, other: Array | complex, /) -> Array:
11741174 self ._array .__ipow__ (other ._array )
11751175 return self
11761176
1177- def __rpow__ (self , other : Array | complex , / ) -> Array :
1177+ def __rpow__ (self , other : Array | int | float | complex , / ) -> Array :
11781178 """
11791179 Performs the operation __rpow__.
11801180 """
@@ -1209,7 +1209,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12091209 res = self ._array .__rrshift__ (other ._array )
12101210 return self .__class__ ._new (res , device = self .device )
12111211
1212- def __isub__ (self , other : Array | complex , / ) -> Array :
1212+ def __isub__ (self , other : Array | int | float | complex , / ) -> Array :
12131213 """
12141214 Performs the operation __isub__.
12151215 """
@@ -1219,7 +1219,7 @@ def __isub__(self, other: Array | complex, /) -> Array:
12191219 self ._array .__isub__ (other ._array )
12201220 return self
12211221
1222- def __rsub__ (self , other : Array | complex , / ) -> Array :
1222+ def __rsub__ (self , other : Array | int | float | complex , / ) -> Array :
12231223 """
12241224 Performs the operation __rsub__.
12251225 """
@@ -1231,7 +1231,7 @@ def __rsub__(self, other: Array | complex, /) -> Array:
12311231 res = self ._array .__rsub__ (other ._array )
12321232 return self .__class__ ._new (res , device = self .device )
12331233
1234- def __itruediv__ (self , other : Array | complex , / ) -> Array :
1234+ def __itruediv__ (self , other : Array | int | float | complex , / ) -> Array :
12351235 """
12361236 Performs the operation __itruediv__.
12371237 """
@@ -1241,7 +1241,7 @@ def __itruediv__(self, other: Array | complex, /) -> Array:
12411241 self ._array .__itruediv__ (other ._array )
12421242 return self
12431243
1244- def __rtruediv__ (self , other : Array | complex , / ) -> Array :
1244+ def __rtruediv__ (self , other : Array | int | float | complex , / ) -> Array :
12451245 """
12461246 Performs the operation __rtruediv__.
12471247 """
@@ -1253,7 +1253,7 @@ def __rtruediv__(self, other: Array | complex, /) -> Array:
12531253 res = self ._array .__rtruediv__ (other ._array )
12541254 return self .__class__ ._new (res , device = self .device )
12551255
1256- def __ixor__ (self , other : Array | int , / ) -> Array :
1256+ def __ixor__ (self , other : Array | bool | int , / ) -> Array :
12571257 """
12581258 Performs the operation __ixor__.
12591259 """
@@ -1263,7 +1263,7 @@ def __ixor__(self, other: Array | int, /) -> Array:
12631263 self ._array .__ixor__ (other ._array )
12641264 return self
12651265
1266- def __rxor__ (self , other : Array | int , / ) -> Array :
1266+ def __rxor__ (self , other : Array | bool | int , / ) -> Array :
12671267 """
12681268 Performs the operation __rxor__.
12691269 """
0 commit comments