diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index e9e03d3b894cb..81e61b1198b96 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -238,6 +238,16 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast(0); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + wi_element &operator=(const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); @@ -248,6 +258,80 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + + wi_element & + operator=(const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend T operator+(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) + rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator+=(const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) + rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend T operator-(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) - rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator-=(const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) - rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend T operator*(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) * rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + wi_element &operator*=(const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( @@ -259,7 +343,108 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } - // TODO: add other arithmetic operators + + friend T operator/(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) / rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator/=(const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) / rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator<(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) < rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator<=(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) <= rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator>(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) > rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator>=(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) >= rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator==(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) == rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator!=(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) != rhs; +#else + (void)lhs; + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } }; template