@@ -54,6 +54,9 @@ def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides
5454
5555def _create_empty_array (numdims , idims , dtype ):
5656 out_arr = c_void_ptr_t (0 )
57+
58+ if numdims == 0 : return out_arr
59+
5760 c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
5861 safe_call (backend .get ().af_create_handle (c_pointer (out_arr ),
5962 numdims , c_pointer (c_dims ), dtype .value ))
@@ -160,19 +163,18 @@ def _slice_to_length(key, dim):
160163
161164def _get_info (dims , buf_len ):
162165 elements = 1
163- numdims = len (dims )
164- idims = [1 ]* 4
165-
166- for i in range (numdims ):
167- elements *= dims [i ]
168- idims [i ] = dims [i ]
169-
170- if (elements == 0 ):
171- if (buf_len != 0 ):
172- idims = [buf_len , 1 , 1 , 1 ]
173- numdims = 1
174- else :
175- raise RuntimeError ("Invalid size" )
166+ numdims = 0
167+ if dims :
168+ numdims = len (dims )
169+ idims = [1 ]* 4
170+ for i in range (numdims ):
171+ elements *= dims [i ]
172+ idims [i ] = dims [i ]
173+ elif (buf_len != 0 ):
174+ idims = [buf_len , 1 , 1 , 1 ]
175+ numdims = 1
176+ else :
177+ raise RuntimeError ("Invalid size" )
176178
177179 return numdims , idims
178180
@@ -382,7 +384,7 @@ class Array(BaseArray):
382384 # arrayfire's __radd__() instead of numpy's __add__()
383385 __array_priority__ = 30
384386
385- def __init__ (self , src = None , dims = ( 0 ,) , dtype = None , is_device = False , offset = None , strides = None ):
387+ def __init__ (self , src = None , dims = None , dtype = None , is_device = False , offset = None , strides = None ):
386388
387389 super (Array , self ).__init__ ()
388390
@@ -449,10 +451,12 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False, offset=None
449451 if type_char is None :
450452 type_char = 'f'
451453
452- numdims = len (dims )
454+ numdims = len (dims ) if dims else 0
455+
453456 idims = [1 ] * 4
454457 for n in range (numdims ):
455458 idims [n ] = dims [n ]
459+
456460 self .arr = _create_empty_array (numdims , idims , to_dtype [type_char ])
457461
458462 def as_type (self , ty ):
0 commit comments