Skip to content

Commit ab8266d

Browse files
allow optional device selection
1 parent 6f3cd94 commit ab8266d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

arraycontext/impl/cupy/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ class CupyArrayContext(ArrayContext):
7171

7272
_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]
7373

74-
def __init__(self) -> None:
74+
def __init__(self, device: int | None = None) -> None:
7575
super().__init__()
7676
self._loopy_transform_cache = {}
7777

78+
if device is not None:
79+
import cupy as cp
80+
cp.cuda.runtime.setDevice(device)
81+
7882
array_types = (CupyNonObjectArray,)
7983

8084
def _get_fake_numpy_namespace(self):

0 commit comments

Comments
 (0)