diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 020441c071e0..90b5601304a2 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -23,89 +23,107 @@ #ifdef __SYCL_DEVICE_ONLY__ template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); template extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL( - T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, + T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixUUMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixUSMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_JointMatrixSUMadINTEL( - __spv::__spirv_JointMatrixINTEL *A, - __spv::__spirv_JointMatrixINTEL *B, - __spv::__spirv_JointMatrixINTEL *C, + __spv::__spirv_JointMatrixINTEL *A, + __spv::__spirv_JointMatrixINTEL *B, + __spv::__spirv_JointMatrixINTEL *C, __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup); template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_CompositeConstruct(const T v); -template extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL( - __spv::__spirv_JointMatrixINTEL *); + __spv::__spirv_JointMatrixINTEL *); -template extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic( - __spv::__spirv_JointMatrixINTEL *, size_t i); + __spv::__spirv_JointMatrixINTEL *, size_t i); -template -extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * -__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, +extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * +__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, T val, size_t i); #ifndef __SPIRV_BUILTIN_DECLARATIONS__ diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp index b3ce746ff46e..815d38b34934 100644 --- a/sycl/include/CL/__spirv/spirv_types.hpp +++ b/sycl/include/CL/__spirv/spirv_types.hpp @@ -112,7 +112,15 @@ enum class MatrixLayout : uint32_t { RowMajor = 0, ColumnMajor = 1, PackedA = 2, - PackedB = 3 + PackedB = 3, + Unused = 4 +}; + +enum class MatrixUse : uint32_t { + MatrixA = 0, + MatrixB = 1, + Accumulator = 2, + Unnecessary = 3 }; // TODO: replace the following W/A with a better solution when we have it. @@ -129,10 +137,13 @@ enum class MatrixLayout : uint32_t { // information to SPIRV translator. // The long term solution would be to introduce a matrix type in Clang and use // it instead of this member. -template +template struct __spirv_JointMatrixINTEL { - T (*Value)[R][C][static_cast(U) + 1][static_cast(S) + 1]; + T(*Value) + [R][C][static_cast(L) + 1][static_cast(S) + 1] + [static_cast(U) + 1]; }; } // namespace __spv diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp deleted file mode 100644 index 2459769e97b7..000000000000 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp +++ /dev/null @@ -1,448 +0,0 @@ -//===------------ matrix-aot-amx.hpp - SYCL matrix ------------*- C++ -*---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// ===--------------------------------------------------------------------=== // -/// -/// We provide new interfaces for matrix muliply in this patch: -/// 1. A new class called joint_matrix is introduced, and the user needs to -/// specify the type of the elements, sizes, and the memory layout. -/// -/// 2. joint_matrix_load is used for loading data from main memory to tiles of -/// AMX or kernel's local memory. -/// -/// 3. joint_matrix_store is used for storing data tiles of AMX or kernel's -/// local memory to main memory. -/// -/// 4. joint_matrix_mad is used for the matrix multiply and add function. -/// It performs the multiply operation on the matrices A and B, accumulates the -/// result with C and returns the result. -/// -/// The following operation can be realized with the interfaces: -/// C = A*B+C -/// 1. All cases where A(int8, any-size, row_major), B(int8, any-size, -/// packed_b), C(int32, any-size, row_major) -/// 2. All cases where A(bf16, any-size, row_major), B(bf16, any-size, -/// packed_b), C(float, any-size, row_major) -/// -/// -// ===--------------------------------------------------------------------=== // - -#pragma once - -#include -#include - -namespace sycl { -__SYCL_INLINE_VER_NAMESPACE(_V1) { -namespace ext { -namespace intel { -namespace detail { -template class submatrix { -public: - _tile1024i tile; - short rows, cols; -}; - -// TODO: we are adding it this way until sycl::dynamic_extent gets implemented. -constexpr size_t dynamic_extent = std::numeric_limits::max(); - -template struct elems_per_dword { - static constexpr size_t value = 1; -}; - -#define ELEMS_PER_DWORD(TYPE, NUM) \ - template <> struct elems_per_dword { \ - static constexpr size_t value = NUM; \ - }; - -ELEMS_PER_DWORD(int8_t, 4) -ELEMS_PER_DWORD(unsigned short, 2) - -} // namespace detail - -namespace experimental::matrix { -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL extern "C" _tile1024i -_tileloadd64_internal(short row, short col, char *buf, size_t stride); -SYCL_EXTERNAL extern "C" _tile1024i -_tdpbssd_internal(unsigned short m, unsigned short n, unsigned short k, - _tile1024i dst, _tile1024i src1, _tile1024i src2); -SYCL_EXTERNAL extern "C" _tile1024i -_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k, - _tile1024i dst, _tile1024i src1, _tile1024i src2); -SYCL_EXTERNAL extern "C" void _tilestored64_internal(short row, short col, - char *buf, size_t stride, - _tile1024i tile); -static _tile1024i tileloadd64_internal(short row, short col, char *buf, - size_t stride) { - return _tileloadd64_internal(row, col, buf, stride); -} -static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n, - unsigned short k, _tile1024i dst, - _tile1024i src1, _tile1024i src2) { - return _tdpbssd_internal(m, n, k, dst, src1, src2); -} -static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n, - unsigned short k, _tile1024i dst, - _tile1024i src1, _tile1024i src2) { - return _tdpbf16ps_internal(m, n, k, dst, src1, src2); -} -static void tilestored64_internal(short row, short col, char *buf, - size_t stride, _tile1024i tile) { - return _tilestored64_internal(row, col, buf, stride, tile); -} -#else -static _tile1024i tileloadd64_internal(short row, short col, char *buf, - size_t stride) { - return __builtin_ia32_tileloadd64_internal(row, col, buf, stride); -} -static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n, - unsigned short k, _tile1024i dst, - _tile1024i src1, _tile1024i src2) { - return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2); -} -static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n, - unsigned short k, _tile1024i dst, - _tile1024i src1, _tile1024i src2) { - return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2); -} -static void tilestored64_internal(short row, short col, char *buf, - size_t stride, _tile1024i tile) { - __builtin_ia32_tilestored64_internal(row, col, buf, stride, tile); -} -#endif - -enum class matrix_layout { row_major, col_major, packed_a, packed_b }; - -inline constexpr size_t tile_size = 16; - -template -struct joint_matrix { - joint_matrix(Group sg) {} - joint_matrix(Group sg, size_t Size) { - static_assert((NumRows != detail::dynamic_extent && - NumCols != detail::dynamic_extent), - "AMX implementation does not support dynamic allocation"); - } - joint_matrix(Group sg, size_t Rows, size_t Cols) { - static_assert((NumRows != detail::dynamic_extent && - NumCols != detail::dynamic_extent), - "AMX implementation does not support dynamic allocation"); - } -}; - -// This template specialization handles cases where matrix can't be accommodated -// by a tile. In this case, we create raw_storage for the matrix and the size -// is the multiply of (TILE*TILE*4). -template -struct joint_matrix< - Group, T, NumRows, NumCols, Layout, - typename std::enable_if::type> { -public: - // trows: Num of tiles in row. - // If T=int8, NumRows==33, trows should be 3=(33+15)/16 - static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size; - // tcols: Num of tiles in column. - static constexpr size_t tcols = - (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size; - // if T=int8, NumRows==33, NumCols==33*4, tile_size==16, then size of - // raw_storage should be 48*48*4. - // FIXME: Greedy Regalloc for tile seems has some limitation and currently we - // do tileload for (16,16*4) instead of varying shapes, so raw_storage's size - // is multiple of (16*16*4) - static constexpr size_t size = trows * tcols * tile_size * tile_size * 4; - // stride is aligned to T instead of int8 - static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T); - int8_t raw_storage[size]; - static constexpr bool isSmall = false; - -public: - matrix_layout layout; - // We do zero-padding for matrix whose size is not fitted into tiles in ctor. - joint_matrix(Group sg) { memset(raw_storage, 0x00, size); } -}; - -// This template specialization handles cases where matrix can be put into a -// tile and users specify layout is packed_a or packed_b -template -struct joint_matrix< - Group, T, NumRows, NumCols, Layout, - typename std::enable_if<(NumRows <= tile_size) && - (NumCols * sizeof(T) / 4 <= tile_size)>::type> { -public: - static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size; - // tcols: Num of tiles in column. - static constexpr size_t tcols = - (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size; - static constexpr size_t size = trows * tcols * tile_size * tile_size * 4; - // stride is aligned to T instead of int8 - static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T); - _tile1024i tile; - static constexpr bool isSmall = true; - matrix_layout layout; - // We do zero-padding for matrix whose size is not fitted into tiles in ctor. - joint_matrix(Group sg) {} -}; - -} // namespace experimental::matrix - -namespace detail { - -using namespace experimental; - -template -inline __SYCL_ALWAYS_INLINE static - typename std::enable_if<(NumRows > matrix::tile_size) || - (NumCols * sizeof(T) / 4 > matrix::tile_size), - void>::type - submatrix_load(detail::submatrix &sub_m, - matrix::joint_matrix jm, - uint32_t row, uint32_t col, size_t stride, - matrix::matrix_layout layout, bool shouldreload) { - uint32_t offset = (row * stride + col); - T *ptr = reinterpret_cast(jm.raw_storage); - ptr += offset; - stride *= sizeof(T); - sub_m.rows = matrix::tile_size; - sub_m.cols = matrix::tile_size * 4; - sub_m.tile = matrix::tileloadd64_internal( - sub_m.rows, sub_m.cols, reinterpret_cast(ptr), stride); -} - -template -inline __SYCL_ALWAYS_INLINE static - typename std::enable_if<(NumRows <= matrix::tile_size) && - (NumCols * sizeof(T) / 4 <= matrix::tile_size), - void>::type - submatrix_load(detail::submatrix &sub_m, - matrix::joint_matrix &jm, - uint32_t row, uint32_t col, size_t stride, - matrix::matrix_layout layout, bool shouldreload) { - if (shouldreload) { - // Force sub_m.tile's shape to be matrix::tile_size * - // matrix::tile_size * 4 - int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4]; - matrix::tilestored64_internal(NumRows, NumCols * sizeof(T), - reinterpret_cast(NewjmC), - matrix::tile_size * 4, jm.tile); - sub_m.rows = matrix::tile_size; - sub_m.cols = matrix::tile_size * 4; - sub_m.tile = matrix::tileloadd64_internal(sub_m.rows, sub_m.cols, - reinterpret_cast(NewjmC), - matrix::tile_size * 4); - return; - } - sub_m.rows = NumRows; - sub_m.cols = NumCols * sizeof(T); - sub_m.tile = jm.tile; -} - -// This handles cases where T1 is int8, T2 is int32. -inline __SYCL_ALWAYS_INLINE static void -submatrix_mad(detail::submatrix &sub_ma, - detail::submatrix &sub_mb, - detail::submatrix &sub_mc) { - sub_mc.tile = matrix::tdpbssd_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols, - sub_mc.tile, sub_ma.tile, sub_mb.tile); -} - -// This handles cases where T1 is int16(bfloat16), T2 is float. -inline __SYCL_ALWAYS_INLINE static void -submatrix_mad(detail::submatrix &sub_ma, - detail::submatrix &sub_mb, - detail::submatrix &sub_mc) { - sub_mc.tile = - matrix::tdpbf16ps_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols, - sub_mc.tile, sub_ma.tile, sub_mb.tile); -} - -template -inline __SYCL_ALWAYS_INLINE static - typename std::enable_if<(NumRows > matrix::tile_size) || - (NumCols * sizeof(T) / 4 > matrix::tile_size), - void>::type - submatrix_store(detail::submatrix &sub_m, - matrix::joint_matrix &jm, - uint32_t row, uint32_t col, size_t stride, - matrix::matrix_layout layout, bool shouldreload) { - uint32_t offset = (row * stride + col); - T *ptr = reinterpret_cast(jm.raw_storage); - ptr += offset; - stride *= sizeof(T); - matrix::tilestored64_internal(sub_m.rows, sub_m.cols, - reinterpret_cast(ptr), stride, - sub_m.tile); -} - -template -inline __SYCL_ALWAYS_INLINE static - typename std::enable_if<(NumRows <= matrix::tile_size) && - (NumCols * sizeof(T) / 4 <= matrix::tile_size), - void>::type - submatrix_store(detail::submatrix &sub_m, - matrix::joint_matrix &jm, - uint32_t row, uint32_t col, size_t stride, - matrix::matrix_layout layout, bool shouldreload) { - if (shouldreload) { - int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4]; - matrix::tilestored64_internal(matrix::tile_size, matrix::tile_size * 4, - reinterpret_cast(NewjmC), - matrix::tile_size * 4, sub_m.tile); - jm.tile = matrix::tileloadd64_internal(NumRows, NumCols * sizeof(T), - reinterpret_cast(NewjmC), - matrix::tile_size * 4); - return; - } - jm.tile = sub_m.tile; -} - -} // namespace detail - -namespace experimental::matrix { - -// This handles cases where matrix can't be accommodated by a tile -template -inline __SYCL_ALWAYS_INLINE typename std::enable_if< - (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type -joint_matrix_load(Group sg, - joint_matrix &jm, - multi_ptr src, size_t stride, - matrix_layout layout) { - T *mem = src.get(); - // memcpy from mem to jm.raw_storage - for (int i = 0; i < NumRows; ++i) { - char *srcptr = reinterpret_cast(mem) + i * stride * sizeof(T); - char *dstptr = - reinterpret_cast(jm.raw_storage) + i * jm.stride * sizeof(T); - // TODO: we may reformat layout. - memcpy(dstptr, srcptr, NumCols * sizeof(T)); - } - jm.layout = layout; -} - -// This handles cases where matrix can be put into a tile -template -inline __SYCL_ALWAYS_INLINE - typename std::enable_if<(NumRows <= tile_size) && - (NumCols * sizeof(T) / 4 <= tile_size), - void>::type - joint_matrix_load(Group sg, - joint_matrix &jm, - multi_ptr src, size_t stride, - matrix_layout layout) { - T *mem = src.get(); - // tileload happens! - jm.tile = - tileloadd64_internal(NumRows, NumCols * sizeof(T), - reinterpret_cast(mem), stride * sizeof(T)); - jm.layout = layout; -} - -// This handles cases where matrix can't be accommodated by a tile -template -inline __SYCL_ALWAYS_INLINE typename std::enable_if< - (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type -joint_matrix_store(Group sg, - joint_matrix &jm, - multi_ptr dst, size_t stride, - matrix_layout layout) { - T *mem = dst.get(); - for (int i = 0; i < NumRows; ++i) { - char *dstptr = reinterpret_cast(mem) + i * stride * sizeof(T); - char *srcptr = - reinterpret_cast(jm.raw_storage) + i * jm.stride * sizeof(T); - // TODO: we may reformat layout. - memcpy(dstptr, srcptr, NumCols * sizeof(T)); - } - return; -} - -// This handles cases where matrix can be put into a tile -template -inline __SYCL_ALWAYS_INLINE - typename std::enable_if<(NumRows <= tile_size) && - (NumCols * sizeof(T) / 4 <= tile_size), - void>::type - joint_matrix_store(Group sg, - joint_matrix &jm, - multi_ptr dst, size_t stride, - matrix_layout layout) { - T *mem = dst.get(); - // tilestore happens! - tilestored64_internal(NumRows, NumCols * sizeof(T), - reinterpret_cast(mem), stride * sizeof(T), - jm.tile); - return; -} - -template -inline __SYCL_ALWAYS_INLINE typename std::enable_if< - ((std::is_same::value && std::is_same::value) || - (std::is_same::value && - std::is_same::value)) && - (LayoutA == matrix_layout::row_major) && - (LayoutB == matrix_layout::packed_b) && - (LayoutC == matrix_layout::row_major), - joint_matrix>::type -joint_matrix_mad(Group sg, - joint_matrix &jmA, - joint_matrix &jmB, - joint_matrix &jmC) { - joint_matrix res(jmC); - constexpr size_t epd = detail::elems_per_dword::value; - // If A is large and C is small, in joint_matrix_load, we do memcpy for A, and - // we do tileload for C whose shape is not tile_size*tile_size*4. In - // joint_matrix_mad, we do tileload for A and shape is tile_size*tile_size*4. - // So we need to reshape C before we do dpbssd. - bool Cshouldreload = res.isSmall && !jmA.isSmall && !jmB.isSmall; - bool Ashouldreload = jmA.isSmall && !jmB.isSmall; - bool Bshouldreload = jmB.isSmall && !jmA.isSmall; - - for (int m = 0; m < res.trows; ++m) { - for (int n = 0; n < res.tcols; ++n) { - detail::submatrix sub_c; - - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - submatrix_load(sub_c, res, m * tile_size, n * tile_size, res.stride, - matrix_layout::row_major, Cshouldreload); - for (int k = 0; k < jmA.tcols; ++k) { // K->int8_t - detail::submatrix sub_a; - detail::submatrix sub_b; - submatrix_load(sub_a, jmA, m * tile_size, k * tile_size * epd, - jmA.stride, matrix_layout::packed_a, Ashouldreload); - // Assume we alreay in vnni format. - submatrix_load(sub_b, jmB, k * tile_size, n * tile_size * epd, - jmB.stride, matrix_layout::packed_b, Bshouldreload); - submatrix_mad(sub_a, sub_b, sub_c); - } - submatrix_store(sub_c, res, m * tile_size, n * tile_size, res.stride, - matrix_layout::row_major, Cshouldreload); - } - } - return res; -} - -} // namespace experimental::matrix -} // namespace intel -} // namespace ext -} // __SYCL_INLINE_VER_NAMESPACE(_V1) -} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp new file mode 100644 index 000000000000..f8ff4aeaa047 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp @@ -0,0 +1,656 @@ +//==------------------ matrix-jit-use.hpp - SYCL matrix ----------------*- C++ +//-*---==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // + +#pragma once + +#include +#include +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace oneapi { +namespace experimental { +namespace matrix { + +// packed_a and packed_b will be replaced by packed once the use implementation +// is stable. +enum class layout { row_major, col_major, packed_a, packed_b, unused }; + +template struct spv_matrix_layout_traits { + static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::Unused; +}; + +#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \ + template <> struct spv_matrix_layout_traits { \ + static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \ + }; + +SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor) +SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor) +SPV_MATRIX_LAYOUT_TRAITS(layout::packed_a, __spv::MatrixLayout::PackedA) +SPV_MATRIX_LAYOUT_TRAITS(layout::packed_b, __spv::MatrixLayout::PackedB) +SPV_MATRIX_LAYOUT_TRAITS(layout::unused, __spv::MatrixLayout::Unused) + +// unnecessary was introduced for backward compatibility. +// Once the use implementation is stable, "unnecessary" value will be omitted +enum class use { a, b, accumulator, unnecessary }; + +template struct spv_matrix_use_traits { + static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA; +}; + +#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE) \ + template <> struct spv_matrix_use_traits { \ + static constexpr __spv::MatrixUse value = SPV_USE; \ + }; + +SPV_MATRIX_USE_TRAITS(use::a, __spv::MatrixUse::MatrixA) +SPV_MATRIX_USE_TRAITS(use::b, __spv::MatrixUse::MatrixB) +SPV_MATRIX_USE_TRAITS(use::accumulator, __spv::MatrixUse::Accumulator) +SPV_MATRIX_USE_TRAITS(use::unnecessary, __spv::MatrixUse::Unnecessary) + +template struct spv_scope_traits {}; +template <> struct spv_scope_traits { + constexpr static auto value = __spv::Scope::Subgroup; +}; +template struct spv_scope_traits> { + constexpr static auto value = __spv::Scope::Workgroup; +}; + +template +class wi_data; +template +struct joint_matrix { +public: + __spv::__spirv_JointMatrixINTEL< + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_scope_traits::value, spv_matrix_use_traits::value> *spvm; + joint_matrix(Group sg) { +#ifndef __SYCL_DEVICE_ONLY__ + (void)sg; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + inline __SYCL_ALWAYS_INLINE wi_data + get_wi_data() { + return wi_data(*this); + } +}; + +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_load( + Group sg, + joint_matrix &res, + multi_ptr src, size_t stride, layout MemL) { +#ifdef __SYCL_DEVICE_ONLY__ + T *Ptr = src.get(); + switch (MemL) { + default: + assert(false && "Invalid Memory Layout!"); + case layout::row_major: + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); + break; + case layout::col_major: + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); + break; + case layout::packed_a: + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); + break; + case layout::packed_b: + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); + break; + } +#else + (void)sg; + (void)res; + (void)src; + (void)stride; + (void)MemL; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ +} + +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_store( + Group sg, + joint_matrix &src, + multi_ptr res, size_t stride, layout MemL) { +#ifdef __SYCL_DEVICE_ONLY__ + T *Ptr = res.get(); + switch (MemL) { + default: + assert(false && "Invalid Memory Layout!"); + case layout::row_major: + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); + break; + case layout::col_major: + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); + break; + case layout::packed_a: + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); + break; + case layout::packed_b: + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); + break; + } +#else + (void)sg; + (void)src; + (void)res; + (void)stride; + (void)MemL; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ +} + +template +inline __SYCL_ALWAYS_INLINE + joint_matrix + joint_matrix_mad( + Group sg, joint_matrix &mA, + joint_matrix &mB, + joint_matrix &mC) { +#ifdef __SYCL_DEVICE_ONLY__ + joint_matrix res(sg); + if constexpr (std::is_same::value && + std::is_same::value && + std::is_same::value) + res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_unsigned::value && std::is_unsigned::value) + res.spvm = __spirv_JointMatrixUUMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_signed::value && std::is_unsigned::value) + res.spvm = __spirv_JointMatrixSUMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else if constexpr (std::is_unsigned::value && std::is_signed::value) + res.spvm = __spirv_JointMatrixUSMadINTEL(mA.spvm, mB.spvm, mC.spvm); + else + res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm); + return res; +#else + (void)sg; + (void)mA; + (void)mB; + (void)mC; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ +} + +template +inline __SYCL_ALWAYS_INLINE void +joint_matrix_fill(Group sg, + joint_matrix &res, + const T2 v) { + // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ + // functions + (void)sg; +#ifdef __SYCL_DEVICE_ONLY__ + res.spvm = + __spirv_CompositeConstruct::value, + spv_matrix_layout_traits::value>( + static_cast(v)); + +#else + (void)res; + (void)v; +#endif // __SYCL_DEVICE_ONLY__ +} + +template +class wi_element { + joint_matrix &M; + std::size_t idx; + +public: + wi_element(joint_matrix &Mat, + std::size_t i) + : M(Mat), idx(i) {} + operator T() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast(0); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + template wi_element &operator=(const T2 &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast(rhs), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element & + operator=(const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + +#if __SYCL_DEVICE_ONLY__ +#define OP(op) \ + template wi_element &operator op##=(const T2 &rhs) { \ + M.spvm = __spirv_VectorInsertDynamic( \ + M.spvm, \ + static_cast(__spirv_VectorExtractDynamic(M.spvm, idx) \ + op static_cast(rhs)), \ + idx); \ + return *this; \ + } +#else // __SYCL_DEVICE_ONLY__ +#define OP(op) \ + template wi_element &operator op##=(const T2 &rhs) { \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } +#endif // __SYCL_DEVICE_ONLY__ + OP(+) + OP(-) + OP(*) + OP(/) +#undef OP +}; + +// Note that similarly to the other matrix functions, uint16_t is used here to +// represent bf16 type. Since the AMX and DPAS implementations don't support +// uint16_t, this interpretation is possible. This design choice was made before +// the introduction of SYCL experimental bfloat16 type. Our plan is to move +// towards using the SYCL bfloat16. But since it is still experimental, we will +// probably keep both uint16 interpretation and SYCL bfloat16. +template +class wi_element { + joint_matrix &M; + std::size_t idx; + +public: + wi_element(joint_matrix &Mat, + std::size_t i) + : M(Mat), idx(i) {} + operator uint16_t() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return std::fabs(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx))) >= + std::numeric_limits::epsilon(); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator=(const uint16_t &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator=( + const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + // We use here the following functions for conversion (bf16=>fp32 and + // fp32=>bf16). This is a workaround until we are able to use + // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are + // supported in the CPU backend + static float make_fp32(uint16_t x) { + unsigned int y = x; + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; + } + + static uint16_t make_bf16(float x) { + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (uint16_t)*res; + } + +#if __SYCL_DEVICE_ONLY__ +#define OP(op) \ + wi_element &operator op##=(const uint16_t &rhs) { \ + M.spvm = __spirv_VectorInsertDynamic( \ + M.spvm, \ + make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx) \ + op make_fp32(rhs))), \ + idx); \ + return *this; \ + } +#else // __SYCL_DEVICE_ONLY__ +#define OP(op) \ + wi_element &operator op##=(const uint16_t &rhs) { \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } +#endif // __SYCL_DEVICE_ONLY__ + OP(+) + OP(-) + OP(*) + OP(/) +#undef OP + + template struct Converter { + static T2 convert(const T1 &from) { return static_cast(from); } + }; + + template struct Converter { + static uint16_t convert(const T &from) { return make_bf16(from); } + }; +#if __SYCL_DEVICE_ONLY__ +#define OP(input_type, type, op) \ + friend type operator op( \ + const wi_element &lhs, \ + const uint16_t &rhs) { \ + return Converter::convert(make_fp32( \ + __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) op make_fp32(rhs)); \ + } \ + friend type operator op( \ + const uint16_t &lhs, \ + const wi_element &rhs) { \ + return Converter::convert(make_fp32( \ + __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \ + } +#else // __SYCL_DEVICE_ONLY__ +#define OP(input_type, type, op) \ + friend type operator op( \ + const wi_element &lhs, \ + const uint16_t &rhs) { \ + (void)lhs; \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } \ + friend type operator op( \ + const uint16_t &lhs, \ + const wi_element &rhs) { \ + (void)lhs; \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } +#endif // __SYCL_DEVICE_ONLY__ + OP(float, uint16_t, +) + OP(float, uint16_t, -) + OP(float, uint16_t, *) + OP(float, uint16_t, /) + OP(bool, bool, ==) + OP(bool, bool, !=) + OP(bool, bool, <) + OP(bool, bool, >) + OP(bool, bool, <=) + OP(bool, bool, >=) +#undef OP +}; + +template +class wi_element { + joint_matrix &M; + std::size_t idx; + +public: + wi_element(joint_matrix &Mat, + std::size_t i) + : M(Mat), idx(i) {} + operator sycl::ext::oneapi::experimental::bfloat16() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_VectorExtractDynamic(M.spvm, idx); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + explicit operator bool() { +#ifdef __SYCL_DEVICE_ONLY__ + return std::fabs(static_cast(__spirv_VectorExtractDynamic( + M.spvm, idx))) >= std::numeric_limits::epsilon(); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element &operator=(const sycl::ext::oneapi::experimental::bfloat16 &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + + wi_element & + operator=(const wi_element &rhs) { +#ifdef __SYCL_DEVICE_ONLY__ + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); + return *this; +#else + (void)rhs; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + +#if __SYCL_DEVICE_ONLY__ +#define OP(opassign, op) \ + wi_element &operator opassign( \ + const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ + M.spvm = __spirv_VectorInsertDynamic( \ + M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \ + return *this; \ + } +#else // __SYCL_DEVICE_ONLY__ +#define OP(opassign, op) \ + wi_element &operator opassign( \ + const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ + (void)rhs; \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } +#endif // __SYCL_DEVICE_ONLY__ + OP(+=, +) + OP(-=, -) + OP(*=, *) + OP(/=, /) +#undef OP + +#if __SYCL_DEVICE_ONLY__ +#define OP(type, op) \ + friend type operator op( \ + const wi_element &lhs, \ + const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ + return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \ + } \ + friend type operator op( \ + const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ + const wi_element &rhs) { \ + return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \ + } + OP(sycl::ext::oneapi::experimental::bfloat16, +) + OP(sycl::ext::oneapi::experimental::bfloat16, -) + OP(sycl::ext::oneapi::experimental::bfloat16, *) + OP(sycl::ext::oneapi::experimental::bfloat16, /) +#undef OP +#define OP(type, op) \ + friend type operator op( \ + const wi_element &lhs, \ + const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ + return type{static_cast(__spirv_VectorExtractDynamic( \ + lhs.M.spvm, lhs.idx)) op static_cast(rhs)}; \ + } \ + friend type operator op( \ + const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ + const wi_element &rhs) { \ + return type{static_cast(__spirv_VectorExtractDynamic( \ + rhs.M.spvm, rhs.idx)) op static_cast(lhs)}; \ + } + OP(bool, ==) + OP(bool, !=) + OP(bool, <) + OP(bool, >) + OP(bool, <=) + OP(bool, >=) +#undef OP +#else // __SYCL_DEVICE_ONLY__ +#define OP(type, op) \ + friend type operator op( \ + const wi_element &, \ + const sycl::ext::oneapi::experimental::bfloat16 &) { \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } \ + friend type operator op( \ + const sycl::ext::oneapi::experimental::bfloat16 &, \ + const wi_element &) { \ + throw runtime_error("joint matrix is not supported on host device.", \ + PI_ERROR_INVALID_DEVICE); \ + } + OP(sycl::ext::oneapi::experimental::bfloat16, +) + OP(sycl::ext::oneapi::experimental::bfloat16, -) + OP(sycl::ext::oneapi::experimental::bfloat16, *) + OP(sycl::ext::oneapi::experimental::bfloat16, /) + OP(bool, ==) + OP(bool, !=) + OP(bool, <) + OP(bool, >) + OP(bool, <=) + OP(bool, >=) +#undef OP +#endif // __SYCL_DEVICE_ONLY__ +}; + +template +class wi_data { + joint_matrix &M; + +public: + wi_data(joint_matrix &Mat) + : M(Mat) {} + size_t length() { +#ifdef __SYCL_DEVICE_ONLY__ + return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + wi_element operator[](size_t i) { + return wi_element(M, i); + } +}; + +} // namespace matrix +} // namespace experimental +} // namespace oneapi +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 3368919af943..edd8ab43d75f 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -38,6 +38,21 @@ SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major, SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA) SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB) +enum class matrix_use { matrix_a, matrix_b, accumulator, unnecessary }; + +template struct spv_matrix_use_traits { + static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA; +}; + +#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE) \ + template <> struct spv_matrix_use_traits { \ + static constexpr __spv::MatrixUse value = SPV_USE; \ + }; + +SPV_MATRIX_USE_TRAITS(matrix_use::matrix_a, __spv::MatrixUse::MatrixA) +SPV_MATRIX_USE_TRAITS(matrix_use::matrix_b, __spv::MatrixUse::MatrixB) +SPV_MATRIX_USE_TRAITS(matrix_use::accumulator, __spv::MatrixUse::Accumulator) +SPV_MATRIX_USE_TRAITS(matrix_use::unnecessary, __spv::MatrixUse::Unnecessary) template struct spv_scope_traits {}; template <> struct spv_scope_traits { constexpr static auto value = __spv::Scope::Subgroup; @@ -57,7 +72,9 @@ template ::value> *spvm; + T, NumRows, NumCols, spv_matrix_layout_traits::value, + spv_scope_traits::value, + spv_matrix_use_traits::value> *spvm; joint_matrix(Group sg) { #ifndef __SYCL_DEVICE_ONLY__ (void)sg; @@ -85,32 +102,36 @@ joint_matrix_load(Group sg, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - res.spvm = - __spirv_JointMatrixLoadINTEL::value>( - Ptr, stride, __spv::MatrixLayout::RowMajor, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); break; case matrix_layout::col_major: - res.spvm = - __spirv_JointMatrixLoadINTEL::value>( - Ptr, stride, __spv::MatrixLayout::ColumnMajor, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); break; case matrix_layout::packed_a: - res.spvm = - __spirv_JointMatrixLoadINTEL::value>( - Ptr, stride, __spv::MatrixLayout::PackedA, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedA, + spv_scope_traits::value); break; case matrix_layout::packed_b: - res.spvm = - __spirv_JointMatrixLoadINTEL::value>( - Ptr, stride, __spv::MatrixLayout::PackedB, - spv_scope_traits::value); + res.spvm = __spirv_JointMatrixLoadINTEL< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, stride, __spv::MatrixLayout::PackedB, + spv_scope_traits::value); break; } #else @@ -137,28 +158,36 @@ joint_matrix_store(Group sg, default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: - __spirv_JointMatrixStoreINTEL::value>( - Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor, - spv_scope_traits::value); + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(Ptr, src.spvm, stride, + __spv::MatrixLayout::RowMajor, + spv_scope_traits::value); break; case matrix_layout::col_major: - __spirv_JointMatrixStoreINTEL::value>( - Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, - spv_scope_traits::value); + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(Ptr, src.spvm, stride, + __spv::MatrixLayout::ColumnMajor, + spv_scope_traits::value); break; case matrix_layout::packed_a: - __spirv_JointMatrixStoreINTEL::value>( - Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA, - spv_scope_traits::value); + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(Ptr, src.spvm, stride, + __spv::MatrixLayout::PackedA, + spv_scope_traits::value); break; case matrix_layout::packed_b: - __spirv_JointMatrixStoreINTEL::value>( - Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB, - spv_scope_traits::value); + __spirv_JointMatrixStoreINTEL< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(Ptr, src.spvm, stride, + __spv::MatrixLayout::PackedB, + spv_scope_traits::value); break; } #else @@ -216,10 +245,10 @@ joint_matrix_fill(Group sg, // functions (void)sg; #ifdef __SYCL_DEVICE_ONLY__ - res.spvm = - __spirv_CompositeConstruct::value>( - static_cast(v)); + res.spvm = __spirv_CompositeConstruct< + T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>(static_cast(v)); #else (void)res; diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp index 01279c0abca9..ecfad58259cc 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix.hpp @@ -16,15 +16,16 @@ #include +// the default is matrix-jit-use but existing tests in llvm-test-suite won't +// fail because we have the "unnecessary" use value #if (SYCL_EXT_ONEAPI_MATRIX == 1) -#if defined(__AMXTILE__) && defined(__AMXINT8__) && defined(__AMXBF16__) -#include -#endif -#endif -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include #include #endif +#if (SYCL_EXT_ONEAPI_MATRIX == 2) +#include +#include +#endif #if (SYCL_EXT_ONEAPI_MATRIX == 3) #include #endif diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp new file mode 100644 index 000000000000..01a6d5f78e0f --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -0,0 +1,423 @@ +//===---------- static-query-use.hpp - SYCL matrix ------------*- C++ -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // +// This file implements the static query interface for the joint_matrix +// experimental extension. AMX, DPAS and different other TPUs support different +// logical sizes and types. The query interface is used to validate user code +// and inform them about supported types, sizes, scope, and layouts by the +// current implementation. Note that this query interface is a compile-time +// query, so there will be no runtime errors. The query interface provides +// three functionalities: +// 1- At compile time, inform the user whether a specific +// combination is valid or not. +// 2- Construct the matrices using a default shape +// if user does not provide a combination +// 3- General query interface for sizes, types, +// static/dynamic, scope. This is needed to void padding by the user, +// for tuning, and efficient code generation if used by a library. + +#pragma once + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace oneapi { +namespace experimental::matrix { + +enum class tpu { + dpas, + amx, +}; +enum class matrix_type { + bf8, + bf16, + fp16, + fp19, // tfloat32 + fp32, + fp64, + sint2, + sint4, + sint8, + sint16, + sint32, + sint64, + uint2, + uint4, + uint8, + uint16, + uint32, + uint64 +}; + +enum class scope_t { sub_group, work_group }; + +template +struct tpu_params; + +#if __cplusplus >= 201703L +template +constexpr bool is_combination_valid_amx(int M, int N, int K) { + // is_same_v is a C++17 feature + if ((std::is_same_v && std::is_same_v && + std::is_same_v && M <= 16 && N <= 16 && K <= 64) || + (std::is_same_v && std::is_same_v && + std::is_same_v && M <= 16 && N <= 16 && K <= 64) || + (std::is_same_v && std::is_same_v && + std::is_same_v && M <= 16 && N <= 16 && K <= 64) || + (std::is_same_v && std::is_same_v && + std::is_same_v && M <= 16 && N <= 16 && K <= 64) || + // bf16 + (std::is_same_v && + std::is_same_v && std::is_same_v && + M <= 16 && N <= 16 && K <= 32)) + return true; + else + return false; +} + +template +constexpr bool are_types_valid_amx() { + if ((std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v && std::is_same_v)) + return true; + else + return false; +} +#endif + +// General query: +// types are not given, no default sizes and no implicit matrix construction +template +struct tpu_params { + static constexpr std::size_t defaultM = -1; // depends on the type + static constexpr std::size_t defaultN = -1; + static constexpr std::size_t defaultK = -1; + + bool dynamic_p = false; // should be true in future implementations because + // AMX hardware supports dynamic sizes + uint32_t numtiles = 8; + scope_t scope = scope_t::sub_group; + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type ctype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + using mt = matrix_type; + static constexpr combination combinations[] = { + {16, 16, 64, mt::sint8, mt::sint8, mt::sint32}, + {16, 16, 64, mt::sint8, mt::uint8, mt::sint32}, + {16, 16, 64, mt::uint8, mt::sint8, mt::sint32}, + {16, 16, 64, mt::uint8, mt::uint8, mt::sint32}, + {16, 16, 32, mt::bf16, mt::bf16, mt::fp32}}; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +#if __cplusplus >= 201703L +// Sizes-only query +// Specialization for when only types are given, need to query only sizes +template +struct tpu_params && + !std::is_same_v && + !std::is_same_v)>::type> { + static_assert((are_types_valid_amx()), + "Invalid types for AMX, supported types are int8_t, uint8_t, " + "and bf16 (Note that unsigned short should be used in the" + "DPC++ code to implement bf16) "); + + // construct the matrices using the default sizes + static constexpr std::size_t defaultM = 16; + static constexpr std::size_t defaultN = 16; + static constexpr std::size_t defaultK = ((sizeof(Ta) == 1) ? 64 : 32); + + template + using joint_matrix_a = + joint_matrix; + template + using joint_matrix_b = + joint_matrix; + template + using joint_matrix_c = joint_matrix; + + bool dynamic_p = false; // should be true in future implementations because + // AMX hardware supports dynamic sizes + uint32_t numtiles = 8; + scope_t scope = scope_t::sub_group; + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type ctype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + static constexpr combination combinations[] = { + {16, 16, (sizeof(Ta) == 1) ? 64 : 32}}; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +// Valid or not: +// Specialization when both types and sizes are given +template +struct tpu_params< + tpu::amx, Ta, Tb, Tc, M, N, K, + typename std::enable_if<( + !std::is_same_v && !std::is_same_v && + !std::is_same_v && M != 0 && N != 0 && K != 0)>::type> { + // Validate that parameters are supported + static_assert( + (M == 0 && N == 0 && K == 0) || + (is_combination_valid_amx(M, N, K)), + "Invalid parameters for AMX, query valid types and maximum sizes " + "using: tpu_params myparams; and then check out " + "myparams.combinations array"); + + // if combination is valid, construct the matrices + + static constexpr std::size_t defaultM = (M != 0) ? M : 16; + static constexpr std::size_t defaultN = (N != 0) ? N : 16; + static constexpr std::size_t defaultK = + (K != 0) ? K : ((sizeof(Ta) == 1) ? 64 : 32); + + template + using joint_matrix_a = + joint_matrix; + template + using joint_matrix_b = + joint_matrix; + template + using joint_matrix_c = joint_matrix; + + bool dynamic_p = false; // should be true in future implementations + // because AMX hardware supports dynamic sizes + uint32_t numtiles = 8; + scope_t scope = scope_t::sub_group; +}; + +// DPAS case +// The DPAS implementation supports the logical capability support of the HW +// So in this case, M, N, K sizes returned by the query represent the logical +// capabilities of the DPAS hardware. + +template +constexpr bool is_combination_valid_dpas(int M, int N, int K) { + if ((std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (M == 1 || M == 2 || M == 4 || M == 8) && + N == 8 && K == 16) || + (std::is_same_v && + std::is_same_v && std::is_same_v && + (M == 1 || M == 2 || M == 4 || M == 8) && N == 8 && K == 16)) + return true; + else + return false; +} + +template +constexpr bool are_types_valid_dpas() { + if ((std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v && std::is_same_v)) + return true; + else + return false; +} +#endif + +// General Query +// specialization for when types are not given --> no default values +template +struct tpu_params { + static constexpr std::size_t defaultM = -1; // depends on the type + static constexpr std::size_t defaultN = -1; + static constexpr std::size_t defaultK = -1; + + bool dynamic_p = false; // no dynamic allocation on the GPU + uint32_t numtiles = -1; // does not apply for DPAS + scope_t scope = scope_t::sub_group; + + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type ctype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + using mt = matrix_type; + static constexpr combination combinations[] = { + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 1, 8, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 2, 8, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 4, 8, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 8, 8, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 1, 8, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 2, 8, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 4, 8, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 8, 8, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 1, 8, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 2, 8, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 4, 8, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 8, 8, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 1, 8, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 2, 8, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 4, 8, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 8, 8, 32}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 1, 8, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 2, 8, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 4, 8, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 8, 8, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 1, 8, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 2, 8, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 4, 8, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 8, 8, 16}, + }; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +// Sizes-only query: +// Specialization for when only types are given, need to query only sizes + +#if __cplusplus >= 201703L +template +struct tpu_params && + !std::is_same_v && + !std::is_same_v)>::type> { + static_assert((are_types_valid_dpas()), + "Invalid types for DPAS, supported types are int8_t, uint8_t, " + "half, and bf16 (Note that unsigned short should be used in the" + "DPC++ code to implement bf16)"); + + // construct the matrices using the default sizes + + static constexpr std::size_t defaultM = 8; + static constexpr std::size_t defaultN = 8; + static constexpr std::size_t defaultK = ((sizeof(Ta) == 1) ? 32 : 16); + + template + using joint_matrix_a = + joint_matrix; + template + using joint_matrix_b = + joint_matrix; + template + using joint_matrix_c = joint_matrix; + + bool dynamic_p = false; // no dynamic allocation on the GPU + uint32_t numtiles = -1; // does not apply for DPAS + scope_t scope = scope_t::sub_group; + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type ctype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + using mt = matrix_type; + static constexpr combination combinations[] = { + // The types used in the initialization below are fake and not used. In + // this case, users already chose the types, they are only looking for the + // sizes + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 1, 8, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 2, 8, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 4, 8, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 8, 8, (sizeof(Ta) == 1) ? 32 : 16}, + }; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +// Valid or not: +// Specialization when both types and sizes are given +template +struct tpu_params< + tpu::dpas, Ta, Tb, Tc, M, N, K, + typename std::enable_if<((!std::is_same_v && M != 0))>::type> { + // Validate that parameters are supported + static_assert((M == 0 && N == 0 && K == 0) || + (is_combination_valid_dpas(M, N, K)), + "Invalid parameters for DPAS, query valid combinations " + "using: tpu_params myparams; and then check out " + "myparams.combinations array"); + + // if combination is valid, construct the matrices + static constexpr std::size_t defaultM = (M != 0) ? M : 8; + static constexpr std::size_t defaultN = (N != 0) ? N : 8; + static constexpr std::size_t defaultK = + (K != 0) ? K : ((sizeof(Ta) == 1) ? 32 : 16); + + template + using joint_matrix_a = + joint_matrix; + template + using joint_matrix_b = + joint_matrix; + template + using joint_matrix_c = joint_matrix; + + bool dynamic_p = false; // no dynamic allocation on the GPU + uint32_t numtiles = -1; // does not apply for DPAS + scope_t scope = scope_t::sub_group; +}; +#endif +} // namespace experimental::matrix +} // namespace oneapi +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/test/matrix/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/matrix-bf16-test-SG-16.cpp index b0113e2c3074..89e34eaf59cb 100644 --- a/sycl/test/matrix/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/matrix-bf16-test-SG-16.cpp @@ -1,7 +1,6 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -180,4 +179,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-bf16-test.cpp b/sycl/test/matrix/matrix-bf16-test.cpp index f871571b4063..32633c240a70 100644 --- a/sycl/test/matrix/matrix-bf16-test.cpp +++ b/sycl/test/matrix/matrix-bf16-test.cpp @@ -1,7 +1,6 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -180,4 +179,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-amx-bf16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test-use.cpp similarity index 62% rename from sycl/test/matrix/matrix-amx-bf16-test.cpp rename to sycl/test/matrix/matrix-bfloat16-test-use.cpp index 450d3dc44b40..1a8b10172101 100644 --- a/sycl/test/matrix/matrix-amx-bf16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-use.cpp @@ -1,16 +1,16 @@ -// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -march=sapphirerapids -fsycl -O2 %s -o %t.out -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 1) +// RUN: %clangxx -fsycl -O2 %s -o %t.out #include +#include -using namespace sycl; -using namespace sycl::ext::intel; -using namespace sycl::ext::intel::experimental::matrix; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; -#define TILE_SZ 16 -#define TM (3 * TILE_SZ - 1) -#define TN (3 * TILE_SZ - 1) -#define TK (9 * TILE_SZ + 2) +static constexpr auto TILE_SZ = 16; +static constexpr auto TM = TILE_SZ - 1; +static constexpr auto TN = TILE_SZ - 1; +static constexpr auto TK = 2 * TILE_SZ - 2; + +static constexpr auto SG_SZ = 16; template struct big_matrix { public: @@ -36,20 +36,19 @@ void matrix_multiply(big_matrix &C, assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2); size_t NDRangeM = M / TM; size_t NDRangeN = N / TN; - buffer bufA(A.get_data(), range<2>(M, K)); - buffer bufB(B.get_data(), range<2>(K, N)); - buffer bufC((float *)C.get_data(), range<2>(M, N)); + sycl::buffer bufA(A.get_data(), sycl::range<2>(M, K)); + sycl::buffer bufB(B.get_data(), sycl::range<2>(K, N)); + sycl::buffer bufC((float *)C.get_data(), sycl::range<2>(M, N)); - queue q; - q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); + sycl::queue q; + q.submit([&](sycl::handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN}, {1, 1}), - [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(1)]] + sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](sycl::nd_item<2> spmd_item) { // The submatrix API has to be accessed by all the workitems in a @@ -57,45 +56,37 @@ void matrix_multiply(big_matrix &C, // code divergence between the workitems const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx; - const auto sg_starty = global_idy; + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a( - sg); + sycl::ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); // For B, since current implementation does not support non-packed // layout, users need to specify the updated VNNI sizes along with // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). - joint_matrix - sub_b(sg); - joint_matrix sub_c(sg); - - // Only the leader perform AMX computation. - if (spmd_item.get_local_id(1) % TILE_SZ) - return; - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + - sg_starty * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { // K->int8_t + sg_starty / SG_SZ * TN, + N, layout::row_major); + for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assume we alreay in vnni format. + K, layout::row_major); + // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty * TN * 2, - N * 2, matrix_layout::packed_b); + sg_starty / SG_SZ * TN * 2, + N * 2, layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + - sg_starty * TN, - N, matrix_layout::row_major); + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } @@ -103,8 +94,10 @@ void matrix_multiply(big_matrix &C, static constexpr size_t MATRIX_M = TM * 2; static constexpr size_t MATRIX_N = TN * 2; static constexpr size_t MATRIX_K = TK * 2; -unsigned short A[MATRIX_M][MATRIX_K]; -unsigned short B[MATRIX_K / 2][MATRIX_N * 2]; +bfloat16 A[MATRIX_M][MATRIX_K]; +bfloat16 B[MATRIX_K / 2][MATRIX_N * 2]; +unsigned short Aref[MATRIX_M][MATRIX_K]; +unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2]; float C[MATRIX_M][MATRIX_N]; float D[MATRIX_M][MATRIX_N]; @@ -142,12 +135,16 @@ void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, int main() { for (int i = 0; i < MATRIX_M; i++) { for (int j = 0; j < MATRIX_K; j++) { - A[i][j] = make_bf16(1.0f * (i + j)); + // Ee create bfloat16 from unsigned short since float-to-bfloat's + // conversion is not allowed. + A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j))); + Aref[i][j] = make_bf16(1.0f * (i + j)); } } for (int i = 0; i < MATRIX_K / 2; i++) { for (int j = 0; j < MATRIX_N * 2; j++) { - B[i][j] = make_bf16(2.0f * i + 3.0f * j); + B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j))); + Bref[i][j] = make_bf16(2.0f * i + 3.0f * j); } } for (int i = 0; i < MATRIX_M; i++) { @@ -159,11 +156,10 @@ int main() { big_matrix MC((float *)&C); big_matrix MD((float *)&D); - big_matrix MA((unsigned short *)&A); - big_matrix MB( - (unsigned short *)&B); + big_matrix MA((bfloat16 *)&A); + big_matrix MB((bfloat16 *)&B); matrix_multiply(MC, MA, MB); - matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M, MATRIX_N, MATRIX_K / 2); bool res = true; @@ -189,4 +185,3 @@ int main() { std::cout << "\n"; } } -#endif diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 384714adef3d..065895fd5498 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -1,7 +1,6 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 #include +#include using namespace sycl::ext::oneapi::experimental::matrix; using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; @@ -188,4 +187,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 6fb68c12867c..2253654a8d93 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -1,8 +1,7 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -174,4 +173,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/matrix-int8-test-SG-16.cpp index 170abc5490ad..0e9855cb2457 100644 --- a/sycl/test/matrix/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/matrix-int8-test-SG-16.cpp @@ -1,7 +1,6 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) +// RUN: %clangxx -fsycl -O2 %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=1 #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -165,4 +164,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/matrix-amx-int8-test.cpp b/sycl/test/matrix/matrix-int8-test-use.cpp similarity index 74% rename from sycl/test/matrix/matrix-amx-int8-test.cpp rename to sycl/test/matrix/matrix-int8-test-use.cpp index 0c383bf2b7b4..dde82891291c 100644 --- a/sycl/test/matrix/matrix-amx-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test-use.cpp @@ -1,17 +1,22 @@ -// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -march=sapphirerapids -fsycl -O2 %s -o %t.out -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 1) +// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s + +// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_4_3_0 = type { [12 x [48 x [5 x [4 x [1 x i8]]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_4_3_2 = type { [12 x [12 x [5 x [4 x [3 x i32]]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_4_3_1 = type { [48 x [12 x [5 x [4 x [2 x i8]]]]] addrspace(4)* } + #include +#include using namespace sycl; -using namespace sycl::ext::intel; -using namespace sycl::ext::intel::experimental::matrix; +using namespace sycl::ext::oneapi::experimental::matrix; #define TILE_SZ 16 -#define TM (4 * TILE_SZ - 4) -#define TN (4 * TILE_SZ - 4) +#define TM (TILE_SZ - 4) +#define TN (TILE_SZ - 4) #define TK (4 * TILE_SZ - 16) +#define SG_SZ 16 + template struct big_matrix { public: T *mat; @@ -47,9 +52,9 @@ void matrix_multiply(big_matrix &C, auto accB = bufB.get_access(cgh); cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN}, {1, 1}), - [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(1)]] + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { // The submatrix API has to be accessed by all the workitems in a @@ -57,44 +62,36 @@ void matrix_multiply(big_matrix &C, // code divergence between the workitems const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx; - const auto sg_starty = global_idy; + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + joint_matrix sub_a(sg); // For B, since current implementation does not support non-packed // layout, users need to specify the updated VNNI sizes along with // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). - joint_matrix - sub_b(sg); - joint_matrix sub_c(sg); - - // Only the leader perform AMX computation. - if (spmd_item.get_local_id(1) % TILE_SZ) - return; + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { // K->int8_t + joint_matrix_fill(sg, sub_c, 0); + for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::packed_a); - // Assume we alreay in vnni format. + K, layout::row_major); + // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty * TN * 4, - N * 4, matrix_layout::packed_b); + sg_starty / SG_SZ * TN * 4, + N * 4, layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + - sg_starty * TN, - N, matrix_layout::row_major); + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } @@ -137,8 +134,8 @@ int main() { } for (int i = 0; i < MATRIX_M; i++) { for (int j = 0; j < MATRIX_N; j++) { - C[i][j] = 1; - D[i][j] = 1; + C[i][j] = 0; + D[i][j] = 0; } } @@ -173,4 +170,3 @@ int main() { std::cout << "\n"; } } -#endif diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index 891d37aec696..b3edeea70258 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,12 +1,11 @@ -// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s +// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s -// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } -// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } -// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_3 = type { [12 x [48 x [1 x [4 x [4 x i8]]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3_3 = type { [12 x [12 x [1 x [4 x [4 x i32]]]]] addrspace(4)* } +// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3_3 = type { [48 x [12 x [4 x [4 x [4 x i8]]]]] addrspace(4)* } -#include -#if (SYCL_EXT_ONEAPI_MATRIX == 2) #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -167,4 +166,3 @@ int main() { std::cout << "\n"; } } -#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) diff --git a/sycl/test/matrix/query.cpp b/sycl/test/matrix/query.cpp index 544f2319aa08..838ae4026460 100644 --- a/sycl/test/matrix/query.cpp +++ b/sycl/test/matrix/query.cpp @@ -1,4 +1,4 @@ -// RUN: %clangxx -fsycl -o query %s +// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX=1 -fsycl -o query %s #include #include