@@ -10,12 +10,6 @@ static float make_fp32(uint16_t x) {
1010 return *res;
1111}
1212
13- static uint16_t make_bf16 (float x) {
14- int *res = reinterpret_cast <int *>(&x);
15- *res = *res >> 16 ;
16- return (uint16_t )*res;
17- }
18-
1913template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
2014public:
2115 T *mat;
@@ -40,7 +34,7 @@ void assert_ops_ref(
4034template <typename T, size_t M, size_t N>
4135void matrix_verify_add (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
4236 const float ref) {
43- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
37+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
4438
4539 q.submit ([&](handler &cgh) {
4640 auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -55,12 +49,13 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
5549 sub_group sg = spmd_item.get_sub_group ();
5650 joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
5751
58- joint_matrix_fill (sg, sub_a, make_bf16 (5.0 ));
52+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
5953
6054 auto wi_slice_a = get_wi_data (sg, sub_a);
6155 for (int i = 0 ; i < wi_slice_a.length (); i++) {
62- wi_slice_a[i] = wi_slice_a[i] + make_bf16 (2 );
56+ wi_slice_a[i] = wi_slice_a[i] + bfloat16 (2 );
6357 }
58+
6459 ext::intel::experimental::matrix::joint_matrix_store (
6560 sg, sub_a,
6661 accA.get_pointer () + (sg_startx * TM) * N +
@@ -74,7 +69,7 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7469template <typename T, size_t M, size_t N>
7570void matrix_verify_sub (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
7671 const float ref) {
77- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
72+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
7873
7974 q.submit ([&](handler &cgh) {
8075 auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -89,11 +84,11 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
8984 sub_group sg = spmd_item.get_sub_group ();
9085 joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
9186
92- joint_matrix_fill (sg, sub_a, make_bf16 (5.0 ));
87+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
9388
9489 auto wi_slice_a = get_wi_data (sg, sub_a);
9590 for (int i = 0 ; i < wi_slice_a.length (); i++) {
96- wi_slice_a[i] = wi_slice_a[i] - make_bf16 (2 );
91+ wi_slice_a[i] = wi_slice_a[i] - bfloat16 (2 );
9792 }
9893 ext::intel::experimental::matrix::joint_matrix_store (
9994 sg, sub_a,
@@ -108,7 +103,7 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
108103template <typename T, size_t M, size_t N>
109104void matrix_verify_mul (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
110105 const float ref) {
111- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
106+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
112107
113108 q.submit ([&](handler &cgh) {
114109 auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -122,11 +117,11 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
122117
123118 sub_group sg = spmd_item.get_sub_group ();
124119 joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
125- joint_matrix_fill (sg, sub_a, make_bf16 (5.0 ));
120+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
126121
127122 auto wi_slice_a = get_wi_data (sg, sub_a);
128123 for (int i = 0 ; i < wi_slice_a.length (); i++) {
129- wi_slice_a[i] = wi_slice_a[i] * make_bf16 (3.0 );
124+ wi_slice_a[i] = wi_slice_a[i] * bfloat16 (3.0 );
130125 }
131126 ext::intel::experimental::matrix::joint_matrix_store (
132127 sg, sub_a,
@@ -141,7 +136,7 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
141136template <typename T, size_t M, size_t N>
142137void matrix_verify_div (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
143138 const float ref) {
144- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
139+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
145140
146141 q.submit ([&](handler &cgh) {
147142 auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -156,11 +151,11 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
156151 sub_group sg = spmd_item.get_sub_group ();
157152 joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
158153
159- joint_matrix_fill (sg, sub_a, make_bf16 (4.0 ));
154+ joint_matrix_fill (sg, sub_a, bfloat16 (4.0 ));
160155
161156 auto wi_slice_a = get_wi_data (sg, sub_a);
162157 for (int i = 0 ; i < wi_slice_a.length (); i++) {
163- wi_slice_a[i] = wi_slice_a[i] / make_bf16 (2.0 );
158+ wi_slice_a[i] = wi_slice_a[i] / bfloat16 (2.0 );
164159 }
165160 ext::intel::experimental::matrix::joint_matrix_store (
166161 sg, sub_a,
@@ -175,7 +170,7 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
175170template <typename T, size_t M, size_t N>
176171void matrix_verify_logic (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
177172 const float ref) {
178- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
173+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
179174
180175 q.submit ([&](handler &cgh) {
181176 auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -189,26 +184,26 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
189184 sub_group sg = spmd_item.get_sub_group ();
190185 joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
191186
192- joint_matrix_fill (sg, sub_a, make_bf16 (5.0 ));
187+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
193188
194189 auto wi_slice_a = get_wi_data (sg, sub_a);
195190 for (int i = 0 ; i < wi_slice_a.length (); i++) {
196191 if (wi_slice_a[i]) {
197- if (wi_slice_a[i] > make_bf16 (2.0 ) ||
198- wi_slice_a[i] >= make_bf16 (2.0 ) ||
199- wi_slice_a[i] < make_bf16 (2.0 ) ||
200- wi_slice_a[i] <= make_bf16 (2.0 )) {
201- T val = (wi_slice_a[i] != make_bf16 (2.0 )) ? wi_slice_a[i]
202- : make_bf16 (2.0 );
203- val = make_bf16 (make_fp32 (val) - static_cast <float >(1 ));
204- val = make_bf16 (make_fp32 (val) + static_cast <float >(1 ));
205- if (wi_slice_a[i] == make_bf16 (2.0 )) {
206- val = make_bf16 (make_fp32 (val) - static_cast <float >(2 ));
207- val = make_bf16 (make_fp32 (val) * static_cast <float >(3 ));
208- val = make_bf16 (make_fp32 (val) / static_cast <float >(2 ));
192+ if (wi_slice_a[i] > bfloat16 (2.0 ) ||
193+ wi_slice_a[i] >= bfloat16 (2.0 ) ||
194+ wi_slice_a[i] < bfloat16 (2.0 ) ||
195+ wi_slice_a[i] <= bfloat16 (2.0 )) {
196+ T val = (wi_slice_a[i] != bfloat16 (2.0 )) ? wi_slice_a[i]
197+ : bfloat16 (2.0 );
198+ val = bfloat16 (make_fp32 (val) - static_cast <float >(1 ));
199+ val = bfloat16 (make_fp32 (val) + static_cast <float >(1 ));
200+ if (wi_slice_a[i] == bfloat16 (2.0 )) {
201+ val = bfloat16 (make_fp32 (val) - static_cast <float >(2 ));
202+ val = bfloat16 (make_fp32 (val) * static_cast <float >(3 ));
203+ val = bfloat16 (make_fp32 (val) / static_cast <float >(2 ));
209204
210205 } else {
211- val = make_bf16 (make_fp32 (val) + static_cast <float >(2 ));
206+ val = bfloat16 (make_fp32 (val) + static_cast <float >(2 ));
212207 }
213208 wi_slice_a[i] = val;
214209 }
@@ -226,7 +221,7 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
226221
227222static constexpr size_t MATRIX_M = TM * 2 ;
228223static constexpr size_t MATRIX_N = TN * 2 ;
229- unsigned short A[MATRIX_M][MATRIX_N];
224+ bfloat16 A[MATRIX_M][MATRIX_N];
230225float D[MATRIX_M][MATRIX_N];
231226
232227void matrix_ops_ref (float *D, int M, int N) {
@@ -240,18 +235,18 @@ void matrix_ops_ref(float *D, int M, int N) {
240235int main () {
241236
242237 big_matrix<float , MATRIX_M, MATRIX_N> MD ((float *)&D);
243- big_matrix<unsigned short , MATRIX_M, MATRIX_N> MA ((unsigned short *)&A);
238+ big_matrix<bfloat16 , MATRIX_M, MATRIX_N> MA ((bfloat16 *)&A);
244239
245240 size_t NDRangeM = MATRIX_M / TM;
246241 size_t NDRangeN = MATRIX_N / TN;
247242 queue q;
248243 nd_range<2 > r ({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ});
249244
250- matrix_verify_add<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
251- matrix_verify_sub<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 3.0 );
252- matrix_verify_mul<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 15.0 );
253- matrix_verify_div<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 2.0 );
254- matrix_verify_logic<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
245+ matrix_verify_add<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
246+ matrix_verify_sub<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 3.0 );
247+ matrix_verify_mul<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 15.0 );
248+ matrix_verify_div<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 2.0 );
249+ matrix_verify_logic<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
255250
256251 return 0 ;
257252}
0 commit comments