 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
 except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False

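The `except AttributeError` works because `torch.uint16` and friends only exist as module attributes on torch >= 2.3; on older versions the attribute access in the `try` body raises before the flag is set. An equivalent standalone probe (illustrative sketch, not part of the commit):

    import torch

    # torch.uint16/uint32/uint64 were added in torch 2.3; on older
    # versions attribute access raises AttributeError, so hasattr works
    # as a version-independent feature check.
    _HAS_LARGE_UINT = hasattr(torch, "uint16")
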
 _array_api_dtypes = {
     torch.bool,
     *_int_dtypes,
     torch.float32,
     torch.float64,
     torch.complex64,
     torch.complex128,
 }

-_promotion_table = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
     # ints
-    (torch.int8, torch.int8): torch.int8,
     (torch.int8, torch.int16): torch.int16,
     (torch.int8, torch.int32): torch.int32,
     (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
     (torch.int16, torch.int32): torch.int32,
     (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
     (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
     # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
     (torch.uint8, torch.int8): torch.int16,
     (torch.uint8, torch.int16): torch.int16,
     (torch.uint8, torch.int32): torch.int32,
     (torch.uint8, torch.int64): torch.int64,
     # floats
-    (torch.float32, torch.float32): torch.float32,
     (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
     # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
     (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
     # Mixed float and complex
     (torch.float32, torch.complex64): torch.complex64,
     (torch.float32, torch.complex128): torch.complex128,
     (torch.float64, torch.complex64): torch.complex128,
     (torch.float64, torch.complex128): torch.complex128,
 }

+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+

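The two trailing `update` calls make the table symmetric and reflexive: the first mirrors every `(a, b)` entry to `(b, a)`, and the second adds the trivial `(a, a) -> a` diagonal for every Array API dtype. That is why the literal above only spells out one ordering of each mixed pair and no same-dtype pairs. A quick illustration of the resulting lookups (sketch, not part of the commit):

    # Both orderings and the diagonal now resolve with a plain dict lookup:
    assert _promotion_table[(torch.uint8, torch.int8)] == torch.int16
    assert _promotion_table[(torch.int8, torch.uint8)] == torch.int16
    assert _promotion_table[(torch.float32, torch.float32)] == torch.float32
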
 def _two_arg(f):
     @_wraps(f)
@@ -301,27 +302,41 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
             out = torch.unsqueeze(out, a)
     return out

+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # We can't upcast uint8 according to the spec because there is no
+        # torch.uint64, so at least upcast to int64 which is what prod does
+        # when axis=None.
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
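In the Array API, `axis=()` means "reduce over no axes": the result keeps the input's values but still gets the dtype upcast a full `sum`/`prod` would apply (small ints go to the default integer dtype). A quick illustration of the helper's behaviour (sketch, not part of the commit):

    x = torch.arange(5, dtype=torch.int16)
    y = _sum_prod_no_axis(x, None)
    assert y.dtype == torch.int64             # upcast, per the spec
    assert torch.equal(y, x.to(torch.int64))  # values are untouched
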
 def prod(x: Array,
          /,
          *,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    ndim = x.ndim

-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
@@ -330,7 +345,7 @@ def prod(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res

     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -343,25 +358,14 @@ def sum(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim

-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.sum(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res

     return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
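`torch.sum` and `torch.prod` reject `keepdims` when no `dim` is given (pytorch#71209), so after a full reduction the helper re-inserts the squeezed dimensions by hand. Its definition sits outside this hunk; a plausible sketch consistent with the call sites above (illustrative only, assuming the name and signature from those calls):

    def _axis_none_keepdims(x, ndim, keepdims):
        # After reducing over all axes, restore `ndim` size-1 dimensions
        # so the result broadcasts against the input, as keepdims=True
        # would have done.
        if keepdims:
            for _ in range(ndim):
                x = torch.unsqueeze(x, 0)
        return x
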
@@ -372,7 +376,7 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.any doesn't support multiple axes
@@ -384,7 +388,7 @@ def any(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.any(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)

     # torch.any doesn't return bool for uint8
@@ -396,7 +400,7 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.all doesn't support multiple axes
@@ -408,7 +412,7 @@ def all(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.all(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)

     # torch.all doesn't return bool for uint8
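The explicit `.to(torch.bool)` casts exist because, as the comments note, torch's `any`/`all` have historically preserved `uint8` instead of returning `bool` for `uint8` input. Illustrative check (sketch, not part of the commit; on recent torch the cast is a no-op):

    mask = torch.tensor([1, 0], dtype=torch.uint8)
    res = torch.any(mask)            # may be uint8 on older torch
    assert res.to(torch.bool).dtype == torch.bool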