
Commit f3ff788

Array4: __cuda_array_interface__ v2
Start implementing the `__cuda_array_interface__` for zero-copy data exchange on Nvidia CUDA GPUs.
1 parent b8eae7f commit f3ff788
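
For orientation, the protocol is a dict-valued attribute that CUDA-aware Python libraries read in order to wrap device memory without copying. A minimal sketch (not from this commit; the pointer and shape values are illustrative) of what version 2 of the dict looks like for a C-ordered, double-precision array laid out as (ncomp, nz, ny, nx):

# Illustrative only: what a consumer sees when reading obj.__cuda_array_interface__
# for a C-ordered float64 array of shape (ncomp, nz, ny, nx) = (1, 4, 3, 2).
cai = {
    "shape": (1, 4, 3, 2),
    "typestr": "<f8",                 # little-endian float64
    "data": (0x7F0000000000, False),  # (device pointer, read_only); placeholder pointer
    "strides": (192, 48, 16, 8),      # in bytes; None would mean C-contiguous
    "version": 2,
}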

File tree: 3 files changed, +214 −71 lines

src/Base/Array4.cpp

Lines changed: 81 additions & 38 deletions
@@ -19,6 +19,61 @@ namespace py = pybind11;
 using namespace amrex;
 
 
+namespace
+{
+    /** CPU: __array_interface__ v3
+     *
+     * https://numpy.org/doc/stable/reference/arrays.interface.html
+     */
+    template<typename T>
+    py::dict
+    array_interface(Array4<T> const & a4)
+    {
+        auto d = py::dict();
+        auto const len = length(a4);
+        // F->C index conversion here
+        // p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride];
+        // Buffer dimensions: zero-size shall not skip dimension
+        auto shape = py::make_tuple(
+            a4.ncomp,
+            len.z <= 0 ? 1 : len.z,
+            len.y <= 0 ? 1 : len.y,
+            len.x <= 0 ? 1 : len.x  // fastest varying index
+        );
+        // buffer protocol strides are in bytes, AMReX strides are elements
+        auto const strides = py::make_tuple(
+            sizeof(T) * a4.nstride,
+            sizeof(T) * a4.kstride,
+            sizeof(T) * a4.jstride,
+            sizeof(T)  // fastest varying index
+        );
+        bool const read_only = false;
+        d["data"] = py::make_tuple(std::intptr_t(a4.dataPtr()), read_only);
+        // note: if we want to keep the same global indexing with non-zero
+        //       box small_end as in AMReX, then we can explore playing with
+        //       this offset as well
+        //d["offset"] = 0;         // default
+        //d["mask"] = py::none();  // default
+
+        d["shape"] = shape;
+        // we could also set this after checking the strides are C-style contiguous:
+        //if (is_contiguous<T>(shape, strides))
+        //    d["strides"] = py::none();  // C-style contiguous
+        //else
+        d["strides"] = strides;
+
+        // type description
+        // for more complicated types, e.g., tuples/structs
+        //d["descr"] = ...;
+        // we currently only need this
+        d["typestr"] = py::format_descriptor<T>::format();
+
+        d["version"] = 3;
+        return d;
+    }
+}
+
+
 template< typename T >
 void make_Array4(py::module &m, std::string typestr)
 {
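
On the CPU side, the dict built by this helper is what NumPy reads through `__array_interface__` to create a zero-copy view. A rough sketch of the intended usage, assuming the pybind11 class is exposed as `amrex.Array4_double` (as in the tests below) and `a4` is such an instance with one component over a 2x3x4 box:

# Hedged sketch: NumPy consumes __array_interface__ v3 to build a zero-copy view.
# Assumes `a4` is an amrex.Array4_double over a (2, 3, 4) box with ncomp == 1.
import numpy as np

view = np.array(a4, copy=False)    # no host copy; wraps the Array4 data pointer
assert view.shape == (1, 4, 3, 2)  # (ncomp, nz, ny, nx), x fastest varying
assert view.dtype == np.float64    # from "typestr"
view[0, 0, 0, 0] = 42.0            # writes through to the AMReX data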
@@ -85,56 +140,44 @@ void make_Array4(py::module &m, std::string typestr)
             return a4;
         }))
 
+
+        // CPU: __array_interface__ v3
+        // https://numpy.org/doc/stable/reference/arrays.interface.html
         .def_property_readonly("__array_interface__", [](Array4<T> const & a4) {
-            auto d = py::dict();
-            auto const len = length(a4);
-            // F->C index conversion here
-            // p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride];
-            // Buffer dimensions: zero-size shall not skip dimension
-            auto shape = py::make_tuple(
-                a4.ncomp,
-                len.z <= 0 ? 1 : len.z,
-                len.y <= 0 ? 1 : len.y,
-                len.x <= 0 ? 1 : len.x  // fastest varying index
-            );
-            // buffer protocol strides are in bytes, AMReX strides are elements
-            auto const strides = py::make_tuple(
-                sizeof(T) * a4.nstride,
-                sizeof(T) * a4.kstride,
-                sizeof(T) * a4.jstride,
-                sizeof(T)  // fastest varying index
-            );
-            bool const read_only = false;
-            d["data"] = py::make_tuple(std::intptr_t(a4.dataPtr()), read_only);
-            // note: if we want to keep the same global indexing with non-zero
-            //       box small_end as in AMReX, then we can explore playing with
-            //       this offset as well
-            //d["offset"] = 0;         // default
-            //d["mask"] = py::none();  // default
-
-            d["shape"] = shape;
-            // we could also set this after checking the strides are C-style contiguous:
-            //if (is_contiguous<T>(shape, strides))
-            //    d["strides"] = py::none();  // C-style contiguous
-            //else
-            d["strides"] = strides;
-
-            d["typestr"] = py::format_descriptor<T>::format();
-            d["version"] = 3;
-            return d;
+            return array_interface(a4);
         })
 
+        // CPU: __array_function__ interface (TODO)
+        //
+        // NEP 18 — A dispatch mechanism for NumPy's high level array functions.
+        // https://numpy.org/neps/nep-0018-array-function-protocol.html
+        // This enables code using NumPy to be directly operated on Array4 arrays.
+        // __array_function__ feature requires NumPy 1.16 or later.
+
 
-        // TODO: __cuda_array_interface__
+        // Nvidia GPUs: __cuda_array_interface__ v2
         // https://numba.readthedocs.io/en/latest/cuda/cuda_array_interface.html
+        .def_property_readonly("__cuda_array_interface__", [](Array4<T> const & a4) {
+            auto d = array_interface(a4);
+
+            // data:
+            // Because the user of the interface may or may not be in the same context, the most common case is to use cuPointerGetAttribute with CU_POINTER_ATTRIBUTE_DEVICE_POINTER in the CUDA driver API (or the equivalent CUDA Runtime API) to retrieve a device pointer that is usable in the currently active context.
+            // TODO For zero-size arrays, use 0 here.
+
+            // ... TODO: wasn't there some stream or device info?
+
+            d["version"] = 2;
+            return d;
+        })
 
 
-        // TODO: __dlpack__
+        // TODO: __dlpack__ __dlpack_device__
         // DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
         // https://dmlc.github.io/dlpack/latest/
         // https://data-apis.org/array-api/latest/design_topics/data_interchange.html
         // https://github.com/data-apis/consortium-feedback/issues/1
         // https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
+        // https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
 
 
         .def("contains", &Array4<T>::contains)

tests/test_array4.py

Lines changed: 80 additions & 33 deletions
@@ -69,36 +69,83 @@ def test_array4():
     x[1, 1, 1] = 44
     assert v_carr2np[0, 1, 1, 1] == 44
 
-    # from cupy
-
-    # to numpy
-
-    # to cupy
-
-    return
-
-    # Check indexing
-    assert obj[0] == 1
-    assert obj[1] == 2
-    assert obj[2] == 3
-    assert obj[-1] == 3
-    assert obj[-2] == 2
-    assert obj[-3] == 1
-    with pytest.raises(IndexError):
-        obj[-4]
-    with pytest.raises(IndexError):
-        obj[3]
-
-    # Check assignment
-    obj[0] = 2
-    obj[1] = 3
-    obj[2] = 4
-    assert obj[0] == 2
-    assert obj[1] == 3
-    assert obj[2] == 4
-
-
-# def test_iv_conversions():
-#     obj = amrex.IntVect.max_vector().numpy()
-#     assert(isinstance(obj, np.ndarray))
-#     assert(obj.dtype == np.int32)
+
+
+@pytest.mark.skipif(
+    amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
+)
+def test_array4_numba():
+    # https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
+    import numba
+
+    # numba -> AMReX Array4
+    x = np.ones(
+        (
+            2,
+            3,
+            4,
+        )
+    )  # type: numpy.ndarray
+
+    # host-to-device copy
+    x_numba = numba.cuda.to_device(
+        x
+    )  # type: numba.cuda.cudadrv.devicearray.DeviceNDArray
+    # x_cupy = cupy.asarray(x_numba)  # type: cupy.ndarray
+    x_arr = amrex.Array4_double(x_numba)  # type: amrex.Array4_double
+
+    assert (
+        x_arr.__cuda_array_interface__["data"][0]
+        == x_numba.__cuda_array_interface__["data"][0]
+    )
+
+    # AMReX -> numba
+    # arr_numba = cuda.as_cuda_array(arr4)
+    # ... or as MultiFab test
+    # TODO
+
+
+@pytest.mark.skipif(
+    amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
+)
+def test_array4_cupy():
+    # https://docs.cupy.dev/en/stable/user_guide/interoperability.html
+    import cupy as cp
+
+    # cupy -> AMReX Array4
+    x = np.ones(
+        (
+            2,
+            3,
+            4,
+        )
+    )  # TODO: merge into next line and create on device?
+    x_cupy = cp.asarray(x)  # type: cupy.ndarray
+    print(f"x_cupy={x_cupy}")
+    print(x_cupy.__cuda_array_interface__)
+
+    # cupy -> AMReX Array4
+    x_arr = amrex.Array4_double(x_cupy)  # type: amrex.Array4_double
+    print(f"x_arr={x_arr}")
+    print(x_arr.__cuda_array_interface__)
+
+    assert (
+        x_arr.__cuda_array_interface__["data"][0]
+        == x_cupy.__cuda_array_interface__["data"][0]
+    )
+
+    # AMReX -> cupy
+    # arr_numba = cuda.as_cuda_array(arr4)
+    # ... or as MultiFab test
+    # TODO
+
+
+@pytest.mark.skipif(
+    amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
+)
+def test_array4_pytorch():
+    # https://docs.cupy.dev/en/stable/user_guide/interoperability.html#pytorch
+    # arr_torch = torch.as_tensor(arr, device='cuda')
+    # assert(arr_torch.__cuda_array_interface__['data'][0] == arr.__cuda_array_interface__['data'][0])
+    # TODO
+
+    pass
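
The reverse direction left as a TODO above ("AMReX -> cupy") could look roughly like the sketch below, since `cupy.asarray` also consumes `__cuda_array_interface__`; `x_arr` and `x_cupy` refer to the objects created in `test_array4_cupy`, and the snippet is an assumption about how the TODO might be filled in, not part of this commit:

# Hedged sketch for the "AMReX -> cupy" TODO: cupy.asarray consumes
# __cuda_array_interface__ and returns a zero-copy device view.
import cupy as cp

arr_cupy = cp.asarray(x_arr)  # type: cupy.ndarray, no device-to-device copy
assert (
    arr_cupy.__cuda_array_interface__["data"][0]
    == x_cupy.__cuda_array_interface__["data"][0]
)
arr_cupy += 1  # operates in place on the same device buffer the Array4 points to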

tests/test_multifab.py

Lines changed: 53 additions & 0 deletions
@@ -152,3 +152,56 @@ def test_mfab_mfiter(mfab, nghost):
         cnt += 1
 
     assert iter(mfab).length == cnt
+
+
+@pytest.mark.skipif(
+    amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
+)
+def test_mfab_ops_cuda_numba():
+    # https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
+    import numba
+
+    # AMReX -> numba
+    # arr_numba = cuda.as_cuda_array(arr4)
+    # TODO
+
+
+@pytest.mark.skipif(
+    amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
+)
+def test_mfab_ops_cuda_cupy():
+    # https://docs.cupy.dev/en/stable/user_guide/interoperability.html
+    import cupy as cp
+
+    # AMReX -> cupy
+    # arr_numba = cuda.as_cuda_array(arr4)
+    # TODO
+
+
+@pytest.mark.skipif(
+    amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
+)
+def test_mfab_ops_cuda_pytorch():
+    # https://docs.cupy.dev/en/stable/user_guide/interoperability.html#pytorch
+    import torch
+
+    # AMReX -> pytorch
+    # arr_torch = torch.as_tensor(arr, device='cuda')
+    # assert(arr_torch.__cuda_array_interface__['data'][0] == arr.__cuda_array_interface__['data'][0])
+    # TODO
+
+
+@pytest.mark.skipif(
+    amrex.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
+)
+def test_mfab_ops_cuda_cuml():
+    # https://github.com/rapidsai/cuml
+    # https://github.com/rapidsai/cudf
+    # maybe better for particles as a dataframe test
+    import cudf
+    import cuml
+
+    # AMReX -> RAPIDSAI cuML
+    # arr_cuml = ...
+    # assert(arr_cuml.__cuda_array_interface__['data'][0] == arr.__cuda_array_interface__['data'][0])
+    # TODO
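
For these MultiFab TODOs, the per-box pattern would presumably mirror the Array4 tests: loop over local boxes, view each box's Array4 on the device, and operate on it with a CUDA-aware library. A rough sketch, assuming a CUDA build, the `mfab` fixture from this test file, and an `mfab.array(mfi)` accessor returning an Array4 (both assumptions, not established by this commit):

# Hedged sketch for the MultiFab TODOs: per-box, zero-copy device views via cupy.
# Assumes a CUDA build, the `mfab` fixture, and an `mfab.array(mfi)` accessor.
import cupy as cp

def scale_mfab_on_device(mfab, factor=2.0):
    for mfi in mfab:            # loop over local boxes, as in test_mfab_mfiter
        arr4 = mfab.array(mfi)  # Array4 view of this box (assumed accessor)
        x = cp.asarray(arr4)    # zero-copy device view via __cuda_array_interface__
        x *= factor             # runs on the GPU, in place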
