-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR][NVVM] Fix assertion failure for insufficient parsing validation of nvvm dialect PureSpecialRangeableRegisterOp #163434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Fix assertion failure for insufficient parsing validation of nvvm dialect PureSpecialRangeableRegisterOp #163434
Conversation
…ialect with PureSpecialRangeableRegisterOp
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Stefan Mada (smada3) ChangesThe nvvm dialect instruction PureSpecialRangeableRegisterOp will trigger an assertion failure in LLVM's constant range class when the lower and upper range bounds are equal, but not equal to the integer minimum or max (as required by constant ranges). This requirement is at line 56 of ConstantRange.cpp:
However, you can write an NVVM dialect operation such as:
which triggers this assertion. This change adds a fix to ensure that this requirement is also enforced by NVVM. Full diff: https://github.com/llvm/llvm-project/pull/163434.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 89fbeb7270a38..e4e23ecf77d8b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -279,6 +279,23 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
SetIntRangeFn setResultRanges) {
nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
}
+
+ // Verify the range attribute satisfies LLVM ConstantRange constructor requirements.
+ ::llvm::LogicalResult $cppClass::verify() {
+ auto rangeAttr = getRange();
+ if (!rangeAttr)
+ return ::mlir::success(); // No range specified, validation passes
+
+ const ::llvm::APInt &lower = rangeAttr->getLower();
+ const ::llvm::APInt &upper = rangeAttr->getUpper();
+
+ // Check LLVM ConstantRange constructor condition
+ if (!(lower != upper || (lower.isMaxValue() || lower.isMinValue()))) {
+ return emitOpError("invalid range attribute: range must be a valid constant range");
+ }
+
+ return ::mlir::success();
+ }
}];
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b3615487716d..27727d9bb5836 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -559,3 +559,13 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
%res = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %try_cancel_response : i1
llvm.return
}
+
+
+// -----
+
+// Test for range validation - invalid range where lower == upper but not at extremes
+func.func @invalid_range_equal_bounds() {
+ // expected-error @below {{invalid range attribute: range must be a valid constant range}}
+ %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
+ return
+}
|
|
||
// Check LLVM ConstantRange constructor condition | ||
if (!(lower != upper || (lower.isMaxValue() || lower.isMinValue()))) { | ||
return emitOpError("invalid range attribute: range must be a valid constant range"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add more information on what the criteria is and also print the lower and upper values.
if (!rangeAttr) | ||
return ::mlir::success(); // No range specified, validation passes | ||
|
||
const ::llvm::APInt &lower = rangeAttr->getLower(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have never verified the range attribute if I'm not mistaken.
I think we can do that for the other ops as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any other ops that I missed that have a range attribute? This bug was initially targeting nvvm.read.ptx.sreg.tid.x but the fix applies to all the special registers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but let's wait on @grypp to have another look as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latest revision LGTM
@grypp If it looks good, please merge on my behalf. I don't have commit permissions. |
Not directly related to this PR, but I was wondering — can we always add and verify the For example, our upper limits have been fixed since CUDA 2.0:
We could also use |
The nvvm dialect instruction PureSpecialRangeableRegisterOp will trigger an assertion failure in LLVM's constant range class when the lower and upper range bounds are equal, but not equal to the integer minimum or max (as required by constant ranges). This requirement is at line 56 of ConstantRange.cpp:
assert((Lower != Upper || (Lower.isMaxValue() || Lower.isMinValue())) && "Lower == Upper, but they aren't min or max value!");
However, you can write an NVVM dialect operation such as:
%0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
which triggers this assertion. This change adds a fix to ensure that this requirement is also enforced by NVVM.