diff --git a/llvm/tools/sycl-post-link/SYCLDeviceRequirements.cpp b/llvm/tools/sycl-post-link/SYCLDeviceRequirements.cpp index a9c791877a079..93650640f97cc 100644 --- a/llvm/tools/sycl-post-link/SYCLDeviceRequirements.cpp +++ b/llvm/tools/sycl-post-link/SYCLDeviceRequirements.cpp @@ -21,11 +21,18 @@ using namespace llvm; void llvm::getSYCLDeviceRequirements( const module_split::ModuleDesc &MD, std::map &Requirements) { - auto ExtractIntegerFromMDNodeOperand = [=](const MDNode *N, - unsigned OpNo) -> int32_t { + auto ExtractSignedIntegerFromMDNodeOperand = [=](const MDNode *N, + unsigned OpNo) -> int64_t { Constant *C = cast(N->getOperand(OpNo).get())->getValue(); - return static_cast(C->getUniqueInteger().getSExtValue()); + return C->getUniqueInteger().getSExtValue(); + }; + + auto ExtractUnsignedIntegerFromMDNodeOperand = + [=](const MDNode *N, unsigned OpNo) -> uint64_t { + Constant *C = + cast(N->getOperand(OpNo).get())->getValue(); + return C->getUniqueInteger().getZExtValue(); }; // { LLVM-IR metadata name , [SYCL/Device requirements] property name }, see: @@ -42,11 +49,15 @@ void llvm::getSYCLDeviceRequirements( for (const Function &F : MD.getModule()) { if (const MDNode *MDN = F.getMetadata(MDName)) { for (size_t I = 0, E = MDN->getNumOperands(); I < E; ++I) { - // Don't put internal aspects (with negative integer value) into the - // requirements, they are used only for device image splitting. - auto Val = ExtractIntegerFromMDNodeOperand(MDN, I); - if (Val >= 0) - Values.insert(Val); + if (std::string(MDName) == "sycl_used_aspects") { + // Don't put internal aspects (with negative integer value) into the + // requirements, they are used only for device image splitting. + auto Val = ExtractSignedIntegerFromMDNodeOperand(MDN, I); + if (Val >= 0) + Values.insert(Val); + } else { + Values.insert(ExtractUnsignedIntegerFromMDNodeOperand(MDN, I)); + } } } } @@ -69,8 +80,7 @@ void llvm::getSYCLDeviceRequirements( for (const Function *F : MD.entries()) { if (auto *MDN = F->getMetadata("intel_reqd_sub_group_size")) { assert(MDN->getNumOperands() == 1); - auto MDValue = ExtractIntegerFromMDNodeOperand(MDN, 0); - assert(MDValue >= 0); + auto MDValue = ExtractUnsignedIntegerFromMDNodeOperand(MDN, 0); if (!SubGroupSize) SubGroupSize = MDValue; else