Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions arrayfire/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ def _get_assign_dims(key, idims):
else:
raise IndexError("Invalid type while assigning to arrayfire.array")


def transpose(a, conj=False):
"""
Perform the transpose on an input.
Expand Down Expand Up @@ -504,7 +503,9 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
if(offset is None and strides is None):
self.arr = _create_array(buf, numdims, idims, to_dtype[_type_char], is_device)
else:
self.arr = _create_strided_array(buf, numdims, idims, to_dtype[_type_char], is_device, offset, strides)
self.arr = _create_strided_array(buf, numdims, idims,
to_dtype[_type_char],
is_device, offset, strides)

else:

Expand Down Expand Up @@ -1159,6 +1160,19 @@ def __setitem__(self, key, val):
except RuntimeError as e:
raise IndexError(str(e))

def _reorder(self):
"""
Returns a reordered array to help interoperate with row major formats.
"""
ndims = self.numdims()
if (ndims == 1):
return self

rdims = tuple(reversed(range(ndims))) + tuple(range(ndims, 4))
out = Array()
safe_call(backend.get().af_reorder(c_pointer(out.arr), self.arr, *rdims))
return out

def to_ctype(self, row_major=False, return_shape=False):
"""
Return the data as a ctype C array after copying to host memory
Expand All @@ -1184,9 +1198,11 @@ def to_ctype(self, row_major=False, return_shape=False):
if (self.arr.value == 0):
raise RuntimeError("Can not call to_ctype on empty array")

tmp = transpose(self) if row_major else self
tmp = self._reorder() if (row_major) else self

ctype_type = to_c_type[self.type()] * self.elements()
res = ctype_type()

safe_call(backend.get().af_get_data_ptr(c_pointer(res), self.arr))
if (return_shape):
return res, self.dims()
Expand Down Expand Up @@ -1312,6 +1328,44 @@ def __array__(self):
safe_call(backend.get().af_get_data_ptr(c_void_ptr_t(res.ctypes.data), self.arr))
return res

def to_ndarray(self, output=None):
"""
Parameters
-----------
output: optional: numpy. default: None

Returns
----------
If output is None: Constructs a numpy.array from arrayfire.Array
If output is not None: copies content of af.array into numpy array.

Note
------

- An exception is thrown when output is not None and it is not contiguous.
- When output is None, The returned array is in fortran contiguous order.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I prefer column major order rather than Fortran contiguous order.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@umar456 That's the notation numpy uses.

"""
if output is None:
return self.__array__()

if (output.dtype != to_typecode[self.type()]):
raise TypeError("Output is not the same type as the array")

if (output.size != self.elements()):
raise RuntimeError("Output size does not match that of input")

flags = output.flags
tmp = None
if flags['F_CONTIGUOUS']:
tmp = self
elif flags['C_CONTIGUOUS']:
tmp = self._reorder()
else:
raise RuntimeError("When output is not None, it must be contiguous")

safe_call(backend.get().af_get_data_ptr(c_void_ptr_t(output.ctypes.data), tmp.arr))
return output

def display(a, precision=4):
"""
Displays the contents of an array.
Expand Down
1 change: 0 additions & 1 deletion arrayfire/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ def tile(a, d0, d1=1, d2=1, d3=1):
safe_call(backend.get().af_tile(c_pointer(out.arr), a.arr, d0, d1, d2, d3))
return out


def reorder(a, d0=1, d1=0, d2=2, d3=3):
"""
Reorder the dimensions of the input.
Expand Down
25 changes: 6 additions & 19 deletions arrayfire/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def _fc_to_af_array(in_ptr, in_shape, in_dtype, is_device=False, copy = True):
"""
res = Array(in_ptr, in_shape, in_dtype, is_device=is_device)

if is_device:
lock_array(res)
pass
if not is_device:
return res

lock_array(res)
return res.copy() if copy else res

def _cc_to_af_array(in_ptr, ndim, in_shape, in_dtype, is_device=False, copy = True):
Expand All @@ -41,24 +41,11 @@ def _cc_to_af_array(in_ptr, ndim, in_shape, in_dtype, is_device=False, copy = Tr
"""
if ndim == 1:
return _fc_to_af_array(in_ptr, in_shape, in_dtype, is_device, copy)
elif ndim == 2:
shape = (in_shape[1], in_shape[0])
res = Array(in_ptr, shape, in_dtype, is_device=is_device)
if is_device: lock_array(res)
return reorder(res, 1, 0)
elif ndim == 3:
shape = (in_shape[2], in_shape[1], in_shape[0])
res = Array(in_ptr, shape, in_dtype, is_device=is_device)
if is_device: lock_array(res)
return reorder(res, 2, 1, 0)
elif ndim == 4:
shape = (in_shape[3], in_shape[2], in_shape[1], in_shape[0])
else:
shape = tuple(reversed(in_shape))
res = Array(in_ptr, shape, in_dtype, is_device=is_device)
if is_device: lock_array(res)
return reorder(res, 3, 2, 1, 0)
else:
raise RuntimeError("Unsupported ndim")

return res._reorder()

_nptype_to_aftype = {'b1' : Dtype.b8,
'u1' : Dtype.u8,
Expand Down
14 changes: 13 additions & 1 deletion arrayfire/tests/simple/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,33 @@ def simple_interop(verbose = False):
a = af.to_array(n)
n2 = np.array(a)
assert((n==n2).all())
n2[:] = 0
a.to_ndarray(n2)
assert((n==n2).all())

n = np.random.random((5,3))
a = af.to_array(n)
n2 = np.array(a)
assert((n==n2).all())
n2[:] = 0
a.to_ndarray(n2)
assert((n==n2).all())

n = np.random.random((5,3,2))
a = af.to_array(n)
n2 = np.array(a)
assert((n==n2).all())
n2[:] = 0
a.to_ndarray(n2)
assert((n==n2).all())

n = np.random.random((5,3,2, 2))
n = np.random.random((5,3,2,2))
a = af.to_array(n)
n2 = np.array(a)
assert((n==n2).all())
n2[:] = 0
a.to_ndarray(n2)
assert((n==n2).all())

if af.AF_PYCUDA_FOUND and af.get_active_backend() == 'cuda':
import pycuda.autoinit
Expand Down