From 9924a0cb657c7322a3c5d328e1b89cb27ee77650 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 12 Mar 2024 12:56:22 +0000 Subject: [PATCH] Update --- jax/_src/array.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index ff13ab7acab9..07d55f018365 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -406,13 +406,13 @@ def __array__(self, dtype=None, context=None, copy=None): def __dlpack__(self, *, stream: int | Any | None = None): if len(self._arrays) != 1: - raise ValueError("__dlpack__ only supported for unsharded arrays.") + raise BufferError("__dlpack__ only supported for unsharded arrays.") from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top return to_dlpack(self, stream=stream) def __dlpack_device__(self) -> tuple[enum.Enum, int]: if len(self._arrays) != 1: - raise ValueError("__dlpack__ only supported for unsharded arrays.") + raise BufferError("__dlpack__ only supported for unsharded arrays.") from jax._src.dlpack import DLDeviceType # pylint: disable=g-import-not-at-top @@ -426,17 +426,17 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: elif "rocm" in platform_version: dl_device_type = DLDeviceType.kDLROCM else: - raise ValueError("Unknown GPU platform for __dlpack__: " + raise BufferError("Unknown GPU platform for __dlpack__: " f"{platform_version}") local_hardware_id = _get_device(self).local_hardware_id if local_hardware_id is None: - raise ValueError("Couldn't get local_hardware_id for __dlpack__") + raise BufferError("Couldn't get local_hardware_id for __dlpack__") return dl_device_type, local_hardware_id else: - raise ValueError( + raise BufferError( "__dlpack__ device only supported for CPU and GPU, got platform: " f"{self.platform()}" )