@@ -204,35 +204,31 @@ def result_type(
204204 # required by the spec rather than using np.result_type. NumPy implements
205205 # too many extra type promotions like int64 + uint64 -> float64, and does
206206 # value-based casting on scalar arrays.
207- A = []
207+ dtypes = []
208208 scalars = []
209209 for a in arrays_and_dtypes :
210- if isinstance (a , Array ):
211- a = a .dtype
210+ if isinstance (a , DType ):
211+ dtypes .append (a )
212+ elif isinstance (a , Array ):
213+ dtypes .append (a .dtype )
212214 elif isinstance (a , (bool , int , float , complex )):
213215 scalars .append (a )
214- elif isinstance (a , np .ndarray ) or a not in _all_dtypes :
215- raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
216- A .append (a )
217-
218- # remove python scalars
219- B = [a for a in A if not isinstance (a , (bool , int , float , complex ))]
216+ else :
217+ raise TypeError (
218+ "result_type() inputs must be Array API arrays, dtypes, or scalars"
219+ )
220220
221- if len ( B ) == 0 :
221+ if not dtypes :
222222 raise ValueError ("at least one array or dtype is required" )
223- elif len (B ) == 1 :
224- result = B [0 ]
225- else :
226- t = B [0 ]
227- for t2 in B [1 :]:
228- t = _result_type (t , t2 )
229- result = t
223+ result = dtypes [0 ]
224+ for t2 in dtypes [1 :]:
225+ result = _result_type (result , t2 )
230226
231- if len ( scalars ) == 0 :
227+ if not scalars :
232228 return result
233229
234230 if get_array_api_strict_flags ()['api_version' ] <= '2023.12' :
235- raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
231+ raise TypeError ("result_type() inputs must be Array API arrays or dtypes" )
236232
237233 # promote python scalars given the result_type for all arrays/dtypes
238234 from ._creation_functions import empty
0 commit comments