From 642575a851f9b905e9c3f39807c8bad5ce322be0 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 7 Jan 2022 14:01:45 +0800 Subject: [PATCH 1/2] [SYCL][Matrix] Add more operators for wi_element 1. We add binary operators: +, -, *, /=, and binary assignment operators +=, -=, /= 2. Add support for logical operators: >, >=, <, <=, ==, != 3. Add support for explicit operator bool() --- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 172 ++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index e9e03d3b894cb..dbd3c72107622 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -238,6 +238,16 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast(0); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + wi_element &operator=(const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); @@ -248,6 +258,75 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + wi_element & + operator=(const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend T operator+(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) + rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + wi_element &operator+=(const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) + rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend T operator-(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) - rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator-=(const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) - rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend T operator*(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) * rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + wi_element &operator*=(const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( @@ -259,6 +338,99 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + friend T operator/(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) / rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator/=(const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) / rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + friend bool + operator<(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) < rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator<=(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) <= rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator>(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) > rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator>=(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) >= rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + friend bool + operator==(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) == rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + friend bool + operator!=(const wi_element &lhs, + const T &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) != rhs; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + // TODO: add other arithmetic operators }; From 17ca9e99303dec6e46df36500a1a667b7666796b Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Sat, 8 Jan 2022 00:41:17 +0800 Subject: [PATCH 2/2] address comments --- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index dbd3c72107622..81e61b1198b96 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -258,6 +258,7 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + wi_element & operator=(const wi_element &rhs) { #ifdef __SYCL_DEVICE_ONLY__ @@ -276,11 +277,13 @@ class wi_element { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) + rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + wi_element &operator+=(const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ M.spvm = __spirv_VectorInsertDynamic( @@ -298,6 +301,7 @@ class wi_element { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) - rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -321,6 +325,7 @@ class wi_element { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) * rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -338,11 +343,13 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + friend T operator/(const wi_element &lhs, const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) / rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -360,12 +367,14 @@ class wi_element { PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + friend bool operator<(const wi_element &lhs, const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) < rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -378,6 +387,7 @@ class wi_element { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) <= rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -390,6 +400,7 @@ class wi_element { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) > rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -402,17 +413,20 @@ class wi_element { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) >= rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } + friend bool operator==(const wi_element &lhs, const T &rhs) { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) == rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -425,13 +439,12 @@ class wi_element { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) != rhs; #else + (void)lhs; (void)rhs; throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); #endif // __SYCL_DEVICE_ONLY__ } - - // TODO: add other arithmetic operators }; template