Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 0 additions & 106 deletions sycl/test/sub_group/common_ocl.cpp

This file was deleted.

54 changes: 39 additions & 15 deletions sycl/test/sub_group/generic-shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
#include "helper.hpp"
#include <CL/sycl.hpp>
#include <complex>
template <typename T>
class pointer_kernel;
template <typename T> class pointer_kernel;

using namespace cl::sycl;

Expand Down Expand Up @@ -59,8 +58,9 @@ void check_pointer(queue &Queue, size_t G = 240, size_t L = 60) {
/* Save GID+SGID */
acc_down[NdItem.get_global_id()] = SG.shuffle_down(ptr, sgid);

/* Save GID XOR SGID */
acc_xor[NdItem.get_global_id()] = SG.shuffle_xor(ptr, sgid);
/* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
acc_xor[NdItem.get_global_id()] =
SG.shuffle_xor(ptr, sgid % SG.get_max_local_range()[0]);
});
});
auto acc = buf.template get_access<access::mode::read_write>();
Expand All @@ -71,30 +71,44 @@ void check_pointer(queue &Queue, size_t G = 240, size_t L = 60) {

size_t sg_size = sgsizeacc[0];
int SGid = 0;
int SGLid = 0;
int SGBeginGid = 0;
for (int j = 0; j < G; j++) {
if (j % L % sg_size == 0) {
SGid++;
SGLid = 0;
SGBeginGid = j;
}
if (j % L == 0) {
SGid = 0;
SGLid = 0;
SGBeginGid = j;
}

/*GID of middle element in every subgroup*/
exit_if_not_equal(acc[j], static_cast<T *>(0x0) + (j / L * L + SGid * sg_size + sg_size / 2),
exit_if_not_equal(acc[j],
static_cast<T *>(0x0) +
(j / L * L + SGid * sg_size + sg_size / 2),
"shuffle");

/* Value GID+SGID for all element except last SGID in SG*/
if (j % L % sg_size + SGid < sg_size && j % L + SGid < L) {
exit_if_not_equal(acc_down[j], static_cast<T *>(0x0) + (j + SGid), "shuffle_down");
exit_if_not_equal(acc_down[j], static_cast<T *>(0x0) + (j + SGid),
"shuffle_down");
}

/* Value GID-SGID for all element except first SGID in SG*/
if (j % L % sg_size >= SGid) {
exit_if_not_equal(acc_up[j], static_cast<T *>(0x0) + (j - SGid), "shuffle_up");
exit_if_not_equal(acc_up[j], static_cast<T *>(0x0) + (j - SGid),
"shuffle_up");
}

/* GID XOR SGID */
exit_if_not_equal(acc_xor[j], static_cast<T *>(0x0) + (j ^ SGid), "shuffle_xor");
/* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
exit_if_not_equal(acc_xor[j],
static_cast<T *>(0x0) +
(SGBeginGid + (SGLid ^ (SGid % sg_size))),
"shuffle_xor");
SGLid++;
}
} catch (exception e) {
std::cout << "SYCL exception caught: " << e.what();
Expand Down Expand Up @@ -145,8 +159,9 @@ void check_struct(queue &Queue, Generator &Gen, size_t G = 240, size_t L = 60) {
/* Save GID+SGID */
acc_down[NdItem.get_global_id()] = SG.shuffle_down(val, sgid);

/* Save GID XOR SGID */
acc_xor[NdItem.get_global_id()] = SG.shuffle_xor(val, sgid);
/* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
acc_xor[NdItem.get_global_id()] =
SG.shuffle_xor(val, sgid % SG.get_max_local_range()[0]);
});
});
auto acc = buf.template get_access<access::mode::read_write>();
Expand All @@ -157,17 +172,23 @@ void check_struct(queue &Queue, Generator &Gen, size_t G = 240, size_t L = 60) {

size_t sg_size = sgsizeacc[0];
int SGid = 0;
int SGLid = 0;
int SGBeginGid = 0;
for (int j = 0; j < G; j++) {
if (j % L % sg_size == 0) {
SGid++;
SGLid = 0;
SGBeginGid = j;
}
if (j % L == 0) {
SGid = 0;
SGLid = 0;
SGBeginGid = j;
}

/*GID of middle element in every subgroup*/
exit_if_not_equal(acc[j], values[j / L * L + SGid * sg_size + sg_size / 2],
"shuffle");
exit_if_not_equal(
acc[j], values[j / L * L + SGid * sg_size + sg_size / 2], "shuffle");

/* Value GID+SGID for all element except last SGID in SG*/
if (j % L % sg_size + SGid < sg_size && j % L + SGid < L) {
Expand All @@ -179,8 +200,11 @@ void check_struct(queue &Queue, Generator &Gen, size_t G = 240, size_t L = 60) {
exit_if_not_equal(acc_up[j], values[j - SGid], "shuffle_up");
}

/* GID XOR SGID */
exit_if_not_equal(acc_xor[j], values[j ^ SGid], "shuffle_xor");
/* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
exit_if_not_equal(acc_xor[j],
values[SGBeginGid + (SGLid ^ (SGid % sg_size))],
"shuffle_xor");
SGLid++;
}
} catch (exception e) {
std::cout << "SYCL exception caught: " << e.what();
Expand Down
42 changes: 30 additions & 12 deletions sycl/test/sub_group/shuffle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

#include "helper.hpp"
#include <CL/sycl.hpp>
template <typename T, int N>
class sycl_subgr;
template <typename T, int N> class sycl_subgr;

using namespace cl::sycl;

Expand Down Expand Up @@ -66,8 +65,9 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
acc_up[NdItem.get_global_id()] = SG.shuffle_up(vwggid, sgid);
/* Save GID+SGID */
acc_down[NdItem.get_global_id()] = SG.shuffle_down(vwggid, sgid);
/* Save GID XOR SGID */
acc_xor[NdItem.get_global_id()] = SG.shuffle_xor(vwggid, sgid);
/* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
acc_xor[NdItem.get_global_id()] =
SG.shuffle_xor(vwggid, sgid % SG.get_max_local_range()[0]);
});
});
auto acc = buf.template get_access<access::mode::read_write>();
Expand All @@ -81,12 +81,18 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {

size_t sg_size = sgsizeacc[0];
int SGid = 0;
int SGLid = 0;
int SGBeginGid = 0;
for (int j = 0; j < G; j++) {
if (j % L % sg_size == 0) {
SGid++;
SGLid = 0;
SGBeginGid = j;
}
if (j % L == 0) {
SGid = 0;
SGLid = 0;
SGBeginGid = j;
}
/*GID of middle element in every subgroup*/
exit_if_not_equal_vec<T, N>(
Expand Down Expand Up @@ -115,17 +121,19 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
exit_if_not_equal_vec(acc2_up[j], vec<T, N>(j - SGid + sg_size),
"shuffle2_up");
}
/* GID XOR SGID */
exit_if_not_equal_vec(acc_xor[j], vec<T, N>(j ^ SGid), "shuffle_xor");
/* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
exit_if_not_equal_vec(acc_xor[j],
vec<T, N>(SGBeginGid + (SGLid ^ (SGid % sg_size))),
"shuffle_xor");
SGLid++;
}
} catch (exception e) {
std::cout << "SYCL exception caught: " << e.what();
exit(1);
}
}

template <typename T>
void check(queue &Queue, size_t G = 240, size_t L = 60) {
template <typename T> void check(queue &Queue, size_t G = 240, size_t L = 60) {
try {
nd_range<1> NdRange(G, L);
buffer<T> buf2(G);
Expand Down Expand Up @@ -171,8 +179,9 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
acc_up[NdItem.get_global_id()] = SG.shuffle_up<T>(wggid, sgid);
/* Save GID+SGID */
acc_down[NdItem.get_global_id()] = SG.shuffle_down<T>(wggid, sgid);
/* Save GID XOR SGID */
acc_xor[NdItem.get_global_id()] = SG.shuffle_xor<T>(wggid, sgid);
/* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
acc_xor[NdItem.get_global_id()] =
SG.shuffle_xor<T>(wggid, sgid % SG.get_max_local_range()[0]);
});
});
auto acc = buf.template get_access<access::mode::read_write>();
Expand All @@ -186,13 +195,20 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {

size_t sg_size = sgsizeacc[0];
int SGid = 0;
int SGLid = 0;
int SGBeginGid = 0;
for (int j = 0; j < G; j++) {
if (j % L % sg_size == 0) {
SGid++;
SGLid = 0;
SGBeginGid = j;
}
if (j % L == 0) {
SGid = 0;
SGLid = 0;
SGBeginGid = j;
}

/*GID of middle element in every subgroup*/
exit_if_not_equal<T>(acc[j], j / L * L + SGid * sg_size + sg_size / 2,
"shuffle");
Expand All @@ -215,8 +231,10 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
if (j % L - SGid + sg_size < L) /* Do not go out LG*/
exit_if_not_equal<T>(acc2_up[j], j - SGid + sg_size, "shuffle2_up");
}
/* GID XOR SGID */
exit_if_not_equal<T>(acc_xor[j], j ^ SGid, "shuffle_xor");
/* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
exit_if_not_equal<T>(acc_xor[j], SGBeginGid + (SGLid ^ (SGid % sg_size)),
"shuffle_xor");
SGLid++;
}
} catch (exception e) {
std::cout << "SYCL exception caught: " << e.what();
Expand Down