Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,13 @@ def __array__(self, dtype=None, context=None, copy=None):

def __dlpack__(self, *, stream: int | Any | None = None):
if len(self._arrays) != 1:
raise ValueError("__dlpack__ only supported for unsharded arrays.")
raise BufferError("__dlpack__ only supported for unsharded arrays.")
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self, stream=stream)

def __dlpack_device__(self) -> tuple[enum.Enum, int]:
if len(self._arrays) != 1:
raise ValueError("__dlpack__ only supported for unsharded arrays.")
raise BufferError("__dlpack__ only supported for unsharded arrays.")

from jax._src.dlpack import DLDeviceType # pylint: disable=g-import-not-at-top

Expand All @@ -426,17 +426,17 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:
elif "rocm" in platform_version:
dl_device_type = DLDeviceType.kDLROCM
else:
raise ValueError("Unknown GPU platform for __dlpack__: "
raise BufferError("Unknown GPU platform for __dlpack__: "
f"{platform_version}")

local_hardware_id = _get_device(self).local_hardware_id
if local_hardware_id is None:
raise ValueError("Couldn't get local_hardware_id for __dlpack__")
raise BufferError("Couldn't get local_hardware_id for __dlpack__")

return dl_device_type, local_hardware_id

else:
raise ValueError(
raise BufferError(
"__dlpack__ device only supported for CPU and GPU, got platform: "
f"{self.platform()}"
)
Expand Down