Skip to content

Commit 698fcb1

Browse files
authored
[mlir][affine] Set overflow flags when lowering [de]linearize_index (#139612)
By analogy to some changess to the affine.apply lowering which put `nsw`s on various multiplications, add appropritae overflow flags to the multiplications and additions that're emitted when lowering affine.delinearize_index and affine.linearize_index to arith ops.
1 parent e581f1c commit 698fcb1

File tree

3 files changed

+71
-29
lines changed

3 files changed

+71
-29
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
11131113
Due to the constraints of affine maps, all the basis elements must
11141114
be strictly positive. A dynamic basis element being 0 or negative causes
11151115
undefined behavior.
1116+
1117+
As with other affine operations, lowerings of delinearize_index may assume
1118+
that the underlying computations do not overflow the index type in a signed sense
1119+
- that is, the product of all basis elements is positive as an `index` as well.
11161120
}];
11171121

11181122
let arguments = (ins Index:$linear_index,
@@ -1195,9 +1199,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
11951199
If the `disjoint` property is present, this is an optimization hint that,
11961200
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
11971201
except that `%idx_0` may be negative to make the index as a whole negative.
1202+
In addition, `disjoint` is an assertion that all bases elements are non-negative.
11981203

11991204
Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
12001205

1206+
As with other affine ops, undefined behavior occurs if the linearization
1207+
computation overflows in the signed sense.
1208+
12011209
Example:
12021210

12031211
```mlir

mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@ using namespace mlir::affine;
3535
///
3636
/// If excess dynamic values are provided, the values at the beginning
3737
/// will be ignored. This allows for dropping the outer bound without
38-
/// needing to manipulate the dynamic value array.
38+
/// needing to manipulate the dynamic value array. `knownPositive`
39+
/// indicases that the values being used to compute the strides are known
40+
/// to be non-negative.
3941
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
4042
ValueRange dynamicBasis,
41-
ArrayRef<int64_t> staticBasis) {
43+
ArrayRef<int64_t> staticBasis,
44+
bool knownNonNegative) {
4245
if (staticBasis.empty())
4346
return {};
4447

@@ -47,11 +50,18 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
4750
size_t dynamicIndex = dynamicBasis.size();
4851
Value dynamicPart = nullptr;
4952
int64_t staticPart = 1;
53+
// The products of the strides can't have overflow by definition of
54+
// affine.*_index.
55+
arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
56+
if (knownNonNegative)
57+
ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
5058
for (int64_t elem : llvm::reverse(staticBasis)) {
5159
if (ShapedType::isDynamic(elem)) {
60+
// Note: basis elements and their products are, definitionally,
61+
// non-negative, so `nuw` is justified.
5262
if (dynamicPart)
5363
dynamicPart = rewriter.create<arith::MulIOp>(
54-
loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
64+
loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
5565
else
5666
dynamicPart = dynamicBasis[dynamicIndex - 1];
5767
--dynamicIndex;
@@ -65,7 +75,8 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
6575
Value stride =
6676
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
6777
if (dynamicPart)
68-
stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
78+
stride =
79+
rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
6980
result.push_back(stride);
7081
}
7182
}
@@ -96,7 +107,8 @@ struct LowerDelinearizeIndexOps
96107
SmallVector<Value> results;
97108
results.reserve(numResults);
98109
SmallVector<Value> strides =
99-
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
110+
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
111+
/*knownNonNegative=*/true);
100112

101113
Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
102114

@@ -108,7 +120,11 @@ struct LowerDelinearizeIndexOps
108120
Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
109121
Value remainderNegative = rewriter.create<arith::CmpIOp>(
110122
loc, arith::CmpIPredicate::slt, remainder, zero);
111-
Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
123+
// If the correction is relevant, this term is <= stride, which is known
124+
// to be positive in `index`. Otherwise, while 2 * stride might overflow,
125+
// this branch won't be taken, so the risk of `poison` is fine.
126+
Value corrected = rewriter.create<arith::AddIOp>(
127+
loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
112128
Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
113129
corrected, remainder);
114130
return mod;
@@ -155,7 +171,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
155171
staticBasis = staticBasis.drop_front();
156172

157173
SmallVector<Value> strides =
158-
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
174+
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
175+
/*knownNonNegative=*/op.getDisjoint());
159176
SmallVector<std::pair<Value, int64_t>> scaledValues;
160177
scaledValues.reserve(numIndexes);
161178

@@ -164,8 +181,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
164181
// our hands on an `OpOperand&` for the loop invariant counting function.
165182
for (auto [stride, idxOp] :
166183
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
167-
Value scaledIdx =
168-
rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
184+
Value scaledIdx = rewriter.create<arith::MulIOp>(
185+
loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
169186
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
170187
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
171188
}
@@ -182,7 +199,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
182199
for (auto [scaledValue, numHoistableLoops] :
183200
llvm::drop_begin(scaledValues)) {
184201
std::ignore = numHoistableLoops;
185-
result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
202+
result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
203+
arith::IntegerOverflowFlags::nsw);
186204
}
187205
rewriter.replaceOp(op, result);
188206
return success();

mlir/test/Dialect/Affine/affine-expand-index-ops.mlir

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
// CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[C50176]]
99
// CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[C50176]]
1010
// CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
11-
// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]]
11+
// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]] overflow<nsw>
1212
// CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
1313
// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[C224]]
1414
// CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[C224]]
1515
// CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
16-
// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]]
16+
// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]] overflow<nsw>
1717
// CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
1818
// CHECK: return %[[N]], %[[P]], %[[Q]]
1919
func.func @delinearize_static_basis(%linear_index: index) -> (index, index, index) {
@@ -30,16 +30,16 @@ func.func @delinearize_static_basis(%linear_index: index) -> (index, index, inde
3030
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
3131
// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
3232
// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
33-
// CHECK: %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]]
33+
// CHECK: %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]] overflow<nsw, nuw>
3434
// CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[STRIDE1]]
3535
// CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[STRIDE1]]
3636
// CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
37-
// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]]
37+
// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]] overflow<nsw>
3838
// CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
3939
// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[DIM2]]
4040
// CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[DIM2]]
4141
// CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
42-
// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]]
42+
// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]] overflow<nsw>
4343
// CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
4444
// CHECK: return %[[N]], %[[P]], %[[Q]]
4545
func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
@@ -58,10 +58,10 @@ func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf3
5858
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
5959
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
6060
// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index
61-
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]]
62-
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]]
63-
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
64-
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
61+
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]] overflow<nsw>
62+
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]] overflow<nsw>
63+
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw>
64+
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw>
6565
// CHECK: return %[[val_1]]
6666
func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
6767
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
@@ -72,11 +72,11 @@ func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
7272

7373
// CHECK-LABEL: @linearize_dynamic
7474
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
75-
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]]
76-
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]]
77-
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]]
78-
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
79-
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
75+
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]] overflow<nsw>
76+
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]] overflow<nsw>
77+
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]] overflow<nsw>
78+
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw>
79+
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw>
8080
// CHECK: return %[[val_1]]
8181
func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
8282
// Note: no outer bounds
@@ -86,17 +86,33 @@ func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: in
8686

8787
// -----
8888

89+
// CHECK-LABEL: @linearize_dynamic_disjoint
90+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
91+
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]] overflow<nsw, nuw>
92+
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]] overflow<nsw>
93+
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]] overflow<nsw>
94+
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw>
95+
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw>
96+
// CHECK: return %[[val_1]]
97+
func.func @linearize_dynamic_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
98+
// Note: no outer bounds
99+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
100+
func.return %0 : index
101+
}
102+
103+
// -----
104+
89105
// CHECK-LABEL: @linearize_sort_adds
90106
// CHECK-SAME: (%[[arg0:.+]]: memref<?xi32>, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
91107
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
92108
// CHECK: scf.for %[[arg3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} {
93109
// CHECK: scf.for %[[arg4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} {
94-
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]]
95-
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]]
96-
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]]
110+
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]] overflow<nsw, nuw>
111+
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]] overflow<nsw>
112+
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]] overflow<nsw>
97113
// Note: even though %arg3 has a lower stride, we add it first
98-
// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]]
99-
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]]
114+
// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]] overflow<nsw>
115+
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]] overflow<nsw>
100116
// CHECK: memref.store %{{.*}}, %[[arg0]][%[[val_1]]]
101117
func.func @linearize_sort_adds(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
102118
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)