Skip to content

Commit 9d7feae

Browse files
committed
Pass allocator when converting scalar args to dev buffers
1 parent b5feb06 commit 9d7feae

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

arraycontext/impl/pytato/compile.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,8 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
527527
if np.isscalar(arg):
528528
if isinstance(actx, PytatoPyOpenCLArrayContext):
529529
import pyopencl.array as cla
530-
arg = cla.to_device(actx.queue, np.array(arg))
530+
arg = cla.to_device(actx.queue, np.array(arg),
531+
allocator=actx.allocator)
531532
elif isinstance(actx, PytatoJAXArrayContext):
532533
import jax
533534
arg = jax.device_put(arg)

0 commit comments

Comments
 (0)