@@ -116,7 +116,9 @@ def _fix_promotion(x1, x2, only_scalar=True):
116116_py_scalars  =  (bool , int , float , complex )
117117
118118
119- def  result_type (* arrays_and_dtypes : Array  |  DType  |  complex ) ->  DType :
119+ def  result_type (
120+     * arrays_and_dtypes : Array  |  DType  |  bool  |  int  |  float  |  complex 
121+ ) ->  DType :
120122    num  =  len (arrays_and_dtypes )
121123
122124    if  num  ==  0 :
@@ -550,10 +552,16 @@ def count_nonzero(
550552        return  result 
551553
552554
553- def  where (condition : Array , x1 : Array , x2 : Array , / ) ->  Array :
555+ def  where (
556+     condition : Array , 
557+     x1 : Array  |  bool  |  int  |  float  |  complex , 
558+     x2 : Array  |  bool  |  int  |  float  |  complex ,
559+     / ,
560+ ) ->  Array :
554561    x1 , x2  =  _fix_promotion (x1 , x2 )
555562    return  torch .where (condition , x1 , x2 )
556563
564+ 
557565# torch.reshape doesn't have the copy keyword 
558566def  reshape (x : Array ,
559567            / ,
@@ -622,7 +630,7 @@ def linspace(start: Union[int, float],
622630# torch.full does not accept an int size 
623631# https://github.com/pytorch/pytorch/issues/70906 
624632def  full (shape : Union [int , Tuple [int , ...]],
625-          fill_value : complex ,
633+          fill_value : bool   |   int   |   float   |   complex ,
626634         * ,
627635         dtype : Optional [DType ] =  None ,
628636         device : Optional [Device ] =  None ,
0 commit comments