Skip to content

Commit 3ee28cc

Browse files
keep device in clone()
Co-authored-by: Alexandru Fikl <[email protected]>
1 parent 340f9dc commit 3ee28cc

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

arraycontext/impl/cupy/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def __init__(self, device: int | None = None) -> None:
7575
super().__init__()
7676
self._loopy_transform_cache = {}
7777

78+
self.device = device
79+
7880
if device is not None:
7981
import cupy as cp
8082
cp.cuda.runtime.setDevice(device)
@@ -88,7 +90,7 @@ def _get_fake_numpy_namespace(self):
8890
# {{{ ArrayContext interface
8991

9092
def clone(self):
91-
return type(self)()
93+
return type(self)(self.device)
9294

9395
@overload
9496
def from_numpy(self, array: np.ndarray) -> Array:

0 commit comments

Comments
 (0)