diff --git a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h index e46a576f1d48e..4a4116312981a 100644 --- a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h +++ b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h @@ -30,6 +30,13 @@ constexpr uint32_t kMagicNumber = 0x07230203; /// The serializer tool ID registered to the Khronos Group constexpr uint32_t kGeneratorNumber = 22; +/// Max number of words +/// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_universal_limits +constexpr uint32_t kMaxWordCount = 65535; + +/// Max number of words for literal +constexpr uint32_t kMaxLiteralWordCount = kMaxWordCount - 3; + /// Appends a SPRI-V module header to `header` with the given `version` and /// `idBound`. void appendModuleHeader(SmallVectorImpl &header, diff --git a/mlir/lib/Target/SPIRV/SPIRVBinaryUtils.cpp b/mlir/lib/Target/SPIRV/SPIRVBinaryUtils.cpp index 31205d8f408f1..0ec468d4c1665 100644 --- a/mlir/lib/Target/SPIRV/SPIRVBinaryUtils.cpp +++ b/mlir/lib/Target/SPIRV/SPIRVBinaryUtils.cpp @@ -13,6 +13,9 @@ #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "llvm/Config/llvm-config.h" // for LLVM_VERSION_MAJOR +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "spirv-binary-utils" using namespace mlir; @@ -67,8 +70,19 @@ uint32_t spirv::getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode) { void spirv::encodeStringLiteralInto(SmallVectorImpl &binary, StringRef literal) { // We need to encode the literal and the null termination. - auto encodingSize = literal.size() / 4 + 1; - auto bufferStartSize = binary.size(); + size_t encodingSize = literal.size() / 4 + 1; + size_t sizeOfDataToCopy = literal.size(); + if (encodingSize >= kMaxLiteralWordCount) { + // Reserve one word for the null termination. + encodingSize = kMaxLiteralWordCount - 1; + // Do not override the last word (null termination) when copying. + sizeOfDataToCopy = (encodingSize - 1) * 4; + LLVM_DEBUG(llvm::dbgs() + << "Truncating string literal to max size (" + << (kMaxLiteralWordCount - 1) << "): " << literal << "\n"); + } + size_t bufferStartSize = binary.size(); binary.resize(bufferStartSize + encodingSize, 0); - std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size()); + std::memcpy(binary.data() + bufferStartSize, literal.data(), + sizeOfDataToCopy); }