Contributor

@smada3 smada3 commented Oct 14, 2025

Ops defined through the NVVM dialect class PureSpecialRangeableRegisterOp trigger an assertion failure in LLVM's ConstantRange class when the lower and upper range bounds are equal but are not the integer minimum or maximum value (as ConstantRange requires). This requirement is at line 56 of ConstantRange.cpp:

assert((Lower != Upper || (Lower.isMaxValue() || Lower.isMinValue())) && "Lower == Upper, but they aren't min or max value!");

However, you can write an NVVM dialect operation such as:

%0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32

which triggers this assertion. This change adds a fix to ensure that this requirement is also enforced by NVVM.
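
To make the invariant concrete, here is a minimal standalone sketch of the same predicate on plain 32-bit unsigned integers. It does not use LLVM: `isValidRangeBounds` is a hypothetical helper name, and `std::uint32_t` stands in for a 32-bit `APInt`, whose `isMinValue()` and `isMaxValue()` refer to the unsigned minimum (0) and maximum (all ones).

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper (not an LLVM API): mirrors the ConstantRange assertion
// for 32-bit bounds. Equal bounds are accepted only at the unsigned extremes,
// where ConstantRange uses them to encode the empty set (min) and the full
// set (max).
constexpr bool isValidRangeBounds(std::uint32_t lower, std::uint32_t upper) {
  return lower != upper || lower == 0u ||
         lower == std::numeric_limits<std::uint32_t>::max();
}
```

Under this check, `range <i32, 32, 32>` is rejected. Because ConstantRange is half-open (`[lower, upper)`), the single value 32 would instead be written with bounds 32 and 33.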

Member

llvmbot commented Oct 14, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Stefan Mada (smada3)

Full diff: https://github.com/llvm/llvm-project/pull/163434.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+17)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 89fbeb7270a38..e4e23ecf77d8b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -279,6 +279,23 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
         SetIntRangeFn setResultRanges) {
         nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
     }
+
+    // Verify the range attribute satisfies LLVM ConstantRange constructor requirements.
+    ::llvm::LogicalResult $cppClass::verify() {
+      auto rangeAttr = getRange();
+      if (!rangeAttr)
+        return ::mlir::success(); // No range specified, validation passes
+      
+      const ::llvm::APInt &lower = rangeAttr->getLower();
+      const ::llvm::APInt &upper = rangeAttr->getUpper();
+      
+      // Check LLVM ConstantRange constructor condition
+      if (!(lower != upper || (lower.isMaxValue() || lower.isMinValue()))) {
+        return emitOpError("invalid range attribute: range must be a valid constant range");
+      }
+      
+      return ::mlir::success();
+    }
   }];
 
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b3615487716d..27727d9bb5836 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -559,3 +559,13 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
   %res = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %try_cancel_response : i1
   llvm.return
 }
+
+
+// -----
+
+// Test for range validation - invalid range where lower == upper but not at extremes
+func.func @invalid_range_equal_bounds() {
+  // expected-error @below {{invalid range attribute: range must be a valid constant range}}
+  %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
+  return
+}


// Check LLVM ConstantRange constructor condition
if (!(lower != upper || (lower.isMaxValue() || lower.isMinValue()))) {
return emitOpError("invalid range attribute: range must be a valid constant range");
Collaborator

I would add more information on what the criterion is, and also print the lower and upper values.
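
A diagnostic along the lines requested here could state the rule and include both bounds. The following is a standalone sketch, not the code that was merged: `isValidRangeBounds` and `describeInvalidRange` are hypothetical helpers, `std::uint32_t` stands in for a 32-bit `APInt`, and the message wording is illustrative only.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>

// Hypothetical helper: the same condition as the ConstantRange assertion,
// expressed on 32-bit unsigned bounds.
inline bool isValidRangeBounds(std::uint32_t lower, std::uint32_t upper) {
  return lower != upper || lower == 0u ||
         lower == std::numeric_limits<std::uint32_t>::max();
}

// Hypothetical helper: builds a message that names the criterion and prints
// both bounds. It is only meaningful for bounds that fail the check above.
inline std::string describeInvalidRange(std::uint32_t lower,
                                        std::uint32_t upper) {
  assert(!isValidRangeBounds(lower, upper) &&
         "bounds are valid; nothing to report");
  return "invalid range attribute: lower and upper bounds may only be equal "
         "when both are the minimum or maximum integer value, but got "
         "lower = " +
         std::to_string(lower) + ", upper = " + std::to_string(upper);
}
```

In the real verifier the same information would presumably be streamed into the result of `emitOpError` rather than assembled as a `std::string`.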

if (!rangeAttr)
return ::mlir::success(); // No range specified, validation passes

const ::llvm::APInt &lower = rangeAttr->getLower();
Member

We have never verified the range attribute if I'm not mistaken.

I think we can do that for the other ops as well.

Contributor Author

Are there any other ops that I missed that have a range attribute? This bug was initially targeting nvvm.read.ptx.sreg.tid.x but the fix applies to all the special registers.

@smada3 smada3 requested review from grypp and joker-eph October 14, 2025 20:51
Collaborator

@joker-eph joker-eph left a comment

LGTM, but let's wait on @grypp to have another look as well

@joker-eph joker-eph changed the title from "[MLIR][NVVM] Fixed assertion failure for insufficient parsing validation of nvvm dialect PureSpecialRangeableRegisterOp" to "[MLIR][NVVM] Fix assertion failure for insufficient parsing validation of nvvm dialect PureSpecialRangeableRegisterOp" Oct 14, 2025
Contributor

@durga4github durga4github left a comment

The latest revision LGTM

Contributor Author

smada3 commented Oct 15, 2025

@grypp If it looks good, please merge on my behalf. I don't have commit permissions.

Member

grypp commented Oct 15, 2025

Not directly related to this PR, but I was wondering: can we always add and verify the range attribute in a meaningful way? I implemented this downstream; maybe we should have this in the NVVM dialect as well.

For example, our upper limits have been fixed since CUDA 2.0:

threadIdx.x -> 0 - 1024
threadIdx.y -> 0 - 1024
threadIdx.z -> 0 - 64

blockIdx.x  -> 0 - 2^32 - 1
blockIdx.y  -> 0 - 65535
blockIdx.z  -> 0 - 65535

laneid      -> 0 - 32
warpsize    -> 32

We could also use #target SM information if we expect these numbers to change in the future.
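
One detail matters if fixed values such as `warpsize -> 32` are ever attached automatically: with half-open `[lower, upper)` bounds, a single value `v` is expressed as `[v, v + 1)`, not `[v, v]`, which is how ConstantRange itself builds a single-element range. The sketch below illustrates only that conversion. `RangeBounds` and `singleValueRange` are hypothetical names, and the other limits in the list are deliberately not encoded because the comment does not say whether their upper ends are inclusive.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a 32-bit half-open range [lower, upper).
struct RangeBounds {
  std::uint32_t lower;
  std::uint32_t upper;
};

// Hypothetical helper: bounds that describe exactly one value. Unsigned
// arithmetic wraps, so the maximum value maps to [max, 0), which is still a
// well-formed (wrapped) range because lower != upper.
inline RangeBounds singleValueRange(std::uint32_t value) {
  RangeBounds bounds{value, static_cast<std::uint32_t>(value + 1u)};
  assert(bounds.lower != bounds.upper &&
         "single-value range must not have equal bounds");
  return bounds;
}
```

For example, `singleValueRange(32)` yields the bounds 32 and 33, which pass the verifier added in this PR, whereas the bounds 32 and 32 from the failing test do not.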

@grypp grypp merged commit e712871 into llvm:main Oct 15, 2025
10 checks passed

5 participants