diff --git a/llvm/lib/SYCLLowerIR/SYCLDeviceRequirements.cpp b/llvm/lib/SYCLLowerIR/SYCLDeviceRequirements.cpp index 6c0f1c952030b..66ae2f10367b1 100644 --- a/llvm/lib/SYCLLowerIR/SYCLDeviceRequirements.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLDeviceRequirements.cpp @@ -38,6 +38,7 @@ static llvm::StringRef ExtractStringFromMDNodeOperand(const MDNode *N, SYCLDeviceRequirements llvm::computeDeviceRequirements(const module_split::ModuleDesc &MD) { SYCLDeviceRequirements Reqs; + bool MultipleReqdWGSize = false; // Process all functions in the module for (const Function &F : MD.getModule()) { if (auto *MDN = F.getMetadata("sycl_used_aspects")) { @@ -64,6 +65,8 @@ llvm::computeDeviceRequirements(const module_split::ModuleDesc &MD) { ExtractUnsignedIntegerFromMDNodeOperand(MDN, I)); if (!Reqs.ReqdWorkGroupSize.has_value()) Reqs.ReqdWorkGroupSize = NewReqdWorkGroupSize; + if (Reqs.ReqdWorkGroupSize != NewReqdWorkGroupSize) + MultipleReqdWGSize = true; } if (auto *MDN = F.getMetadata("sycl_joint_matrix")) { @@ -99,6 +102,14 @@ llvm::computeDeviceRequirements(const module_split::ModuleDesc &MD) { assert(*Reqs.SubGroupSize == static_cast(MDValue)); } } + + // Usually, we would only expect one ReqdWGSize, as the module passed to + // this function would be split according to that. However, when splitting + // is disabled, this cannot be guaranteed. In this case, we reset the value, + // which makes it so that no reqd_work_group_size data is attached to + // the device image. + if (MultipleReqdWGSize) + Reqs.ReqdWorkGroupSize.reset(); return Reqs; } diff --git a/sycl/test-e2e/Regression/no-split-reqd-wg-size.cpp b/sycl/test-e2e/Regression/no-split-reqd-wg-size.cpp new file mode 100644 index 0000000000000..fcc2764de8eaa --- /dev/null +++ b/sycl/test-e2e/Regression/no-split-reqd-wg-size.cpp @@ -0,0 +1,28 @@ +// This test checks that with -fsycl-device-code-split=off, kernels +// with different reqd_work_group_size dimensions can be launched. 
+ +// RUN: %{build} -fsycl -fsycl-device-code-split=off -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: hip + +#include <sycl/sycl.hpp> + +using namespace sycl; + +#define TEST(...) \ + { \ + range globalRange(__VA_ARGS__); \ + range localRange(__VA_ARGS__); \ + nd_range NDRange(globalRange, localRange); \ + q.parallel_for(NDRange, \ + [=](auto) [[sycl::reqd_work_group_size(__VA_ARGS__)]] {}); \ + } + +int main(int argc, char **argv) { + queue q; + TEST(4); + TEST(4, 5); + TEST(4, 5, 6); + return 0; +}