Skip to content

Commit fe2e35e

Browse files
[SYCL][Matrix] Add more operators for wi_element (#5270)
1. We add binary operators: +, -, *, /=, and binary assignment operators +=, -=, /= 2. Add support for logical operators: >, >=, <, <=, ==, != 3. Add support for explicit operator bool()
1 parent 33cfb9f commit fe2e35e

File tree

1 file changed

+186
-1
lines changed

1 file changed

+186
-1
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,16 @@ class wi_element {
238238
PI_INVALID_DEVICE);
239239
#endif // __SYCL_DEVICE_ONLY__
240240
}
241+
242+
explicit operator bool() {
243+
#ifdef __SYCL_DEVICE_ONLY__
244+
return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
245+
#else
246+
throw runtime_error("joint matrix is not supported on host device.",
247+
PI_INVALID_DEVICE);
248+
#endif // __SYCL_DEVICE_ONLY__
249+
}
250+
241251
wi_element &operator=(const T &rhs) {
242252
#ifdef __SYCL_DEVICE_ONLY__
243253
M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
@@ -248,6 +258,80 @@ class wi_element {
248258
PI_INVALID_DEVICE);
249259
#endif // __SYCL_DEVICE_ONLY__
250260
}
261+
262+
wi_element &
263+
operator=(const wi_element<T, NumRows, NumCols, Layout, Group> &rhs) {
264+
#ifdef __SYCL_DEVICE_ONLY__
265+
M.spvm = __spirv_VectorInsertDynamic(
266+
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
267+
return *this;
268+
#else
269+
(void)rhs;
270+
throw runtime_error("joint matrix is not supported on host device.",
271+
PI_INVALID_DEVICE);
272+
#endif // __SYCL_DEVICE_ONLY__
273+
}
274+
275+
friend T operator+(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
276+
const T &rhs) {
277+
#ifdef __SYCL_DEVICE_ONLY__
278+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) + rhs;
279+
#else
280+
(void)lhs;
281+
(void)rhs;
282+
throw runtime_error("joint matrix is not supported on host device.",
283+
PI_INVALID_DEVICE);
284+
#endif // __SYCL_DEVICE_ONLY__
285+
}
286+
287+
wi_element &operator+=(const T &rhs) {
288+
#ifdef __SYCL_DEVICE_ONLY__
289+
M.spvm = __spirv_VectorInsertDynamic(
290+
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) + rhs, idx);
291+
return *this;
292+
#else
293+
(void)rhs;
294+
throw runtime_error("joint matrix is not supported on host device.",
295+
PI_INVALID_DEVICE);
296+
#endif // __SYCL_DEVICE_ONLY__
297+
}
298+
299+
friend T operator-(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
300+
const T &rhs) {
301+
#ifdef __SYCL_DEVICE_ONLY__
302+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) - rhs;
303+
#else
304+
(void)lhs;
305+
(void)rhs;
306+
throw runtime_error("joint matrix is not supported on host device.",
307+
PI_INVALID_DEVICE);
308+
#endif // __SYCL_DEVICE_ONLY__
309+
}
310+
311+
wi_element &operator-=(const T &rhs) {
312+
#ifdef __SYCL_DEVICE_ONLY__
313+
M.spvm = __spirv_VectorInsertDynamic(
314+
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) - rhs, idx);
315+
return *this;
316+
#else
317+
(void)rhs;
318+
throw runtime_error("joint matrix is not supported on host device.",
319+
PI_INVALID_DEVICE);
320+
#endif // __SYCL_DEVICE_ONLY__
321+
}
322+
323+
friend T operator*(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
324+
const T &rhs) {
325+
#ifdef __SYCL_DEVICE_ONLY__
326+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) * rhs;
327+
#else
328+
(void)lhs;
329+
(void)rhs;
330+
throw runtime_error("joint matrix is not supported on host device.",
331+
PI_INVALID_DEVICE);
332+
#endif // __SYCL_DEVICE_ONLY__
333+
}
334+
251335
wi_element &operator*=(const T &rhs) {
252336
#ifdef __SYCL_DEVICE_ONLY__
253337
M.spvm = __spirv_VectorInsertDynamic(
@@ -259,7 +343,108 @@ class wi_element {
259343
PI_INVALID_DEVICE);
260344
#endif // __SYCL_DEVICE_ONLY__
261345
}
262-
// TODO: add other arithmetic operators
346+
347+
friend T operator/(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
348+
const T &rhs) {
349+
#ifdef __SYCL_DEVICE_ONLY__
350+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) / rhs;
351+
#else
352+
(void)lhs;
353+
(void)rhs;
354+
throw runtime_error("joint matrix is not supported on host device.",
355+
PI_INVALID_DEVICE);
356+
#endif // __SYCL_DEVICE_ONLY__
357+
}
358+
359+
wi_element &operator/=(const T &rhs) {
360+
#ifdef __SYCL_DEVICE_ONLY__
361+
M.spvm = __spirv_VectorInsertDynamic(
362+
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) / rhs, idx);
363+
return *this;
364+
#else
365+
(void)rhs;
366+
throw runtime_error("joint matrix is not supported on host device.",
367+
PI_INVALID_DEVICE);
368+
#endif // __SYCL_DEVICE_ONLY__
369+
}
370+
371+
friend bool
372+
operator<(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
373+
const T &rhs) {
374+
#ifdef __SYCL_DEVICE_ONLY__
375+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) < rhs;
376+
#else
377+
(void)lhs;
378+
(void)rhs;
379+
throw runtime_error("joint matrix is not supported on host device.",
380+
PI_INVALID_DEVICE);
381+
#endif // __SYCL_DEVICE_ONLY__
382+
}
383+
384+
friend bool
385+
operator<=(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
386+
const T &rhs) {
387+
#ifdef __SYCL_DEVICE_ONLY__
388+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) <= rhs;
389+
#else
390+
(void)lhs;
391+
(void)rhs;
392+
throw runtime_error("joint matrix is not supported on host device.",
393+
PI_INVALID_DEVICE);
394+
#endif // __SYCL_DEVICE_ONLY__
395+
}
396+
397+
friend bool
398+
operator>(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
399+
const T &rhs) {
400+
#ifdef __SYCL_DEVICE_ONLY__
401+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) > rhs;
402+
#else
403+
(void)lhs;
404+
(void)rhs;
405+
throw runtime_error("joint matrix is not supported on host device.",
406+
PI_INVALID_DEVICE);
407+
#endif // __SYCL_DEVICE_ONLY__
408+
}
409+
410+
friend bool
411+
operator>=(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
412+
const T &rhs) {
413+
#ifdef __SYCL_DEVICE_ONLY__
414+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) >= rhs;
415+
#else
416+
(void)lhs;
417+
(void)rhs;
418+
throw runtime_error("joint matrix is not supported on host device.",
419+
PI_INVALID_DEVICE);
420+
#endif // __SYCL_DEVICE_ONLY__
421+
}
422+
423+
friend bool
424+
operator==(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
425+
const T &rhs) {
426+
#ifdef __SYCL_DEVICE_ONLY__
427+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) == rhs;
428+
#else
429+
(void)lhs;
430+
(void)rhs;
431+
throw runtime_error("joint matrix is not supported on host device.",
432+
PI_INVALID_DEVICE);
433+
#endif // __SYCL_DEVICE_ONLY__
434+
}
435+
436+
friend bool
437+
operator!=(const wi_element<T, NumRows, NumCols, Layout, Group> &lhs,
438+
const T &rhs) {
439+
#ifdef __SYCL_DEVICE_ONLY__
440+
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) != rhs;
441+
#else
442+
(void)lhs;
443+
(void)rhs;
444+
throw runtime_error("joint matrix is not supported on host device.",
445+
PI_INVALID_DEVICE);
446+
#endif // __SYCL_DEVICE_ONLY__
447+
}
263448
};
264449

265450
template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,

0 commit comments

Comments
 (0)