Skip to content

Conversation

krzysz00
Copy link
Contributor

If an affine.apply uses every result of an
affine.delinearize_index operaration in an expresession of the form x_0 * S_0 + x_1 * S_1 + ... + x_n * S_n + ..., where S_i is the "stride" of the i-th delinerization result (the value it got divided by), then, that chain of additions contains the inverse of the affine.delinearize_index.

We don't want to compose affine.delinearize_index into affine.apply in general, since this leads to "simplifications" (mainly the x % y => x - (x / y) * y rewrite) thate are bad for code generation and algetbraic reasoning. However, if we do see an exact inverse, we should cancel it out.

If an `affine.apply` uses every result of an
`affine.delinearize_index` operaration in an expresession of the form
x_0 * S_0 + x_1 * S_1 + ... + x_n * S_n + ..., where S_i is the
"stride" of the i-th delinerization result (the value it got divided
by), then, that chain of additions contains the inverse of the
affine.delinearize_index.

We don't want to compose affine.delinearize_index into affine.apply in
general, since this leads to "simplifications" (mainly the
`x % y => x - (x / y) * y` rewrite) thate are bad for code generation
and algetbraic reasoning. However, if we do see an exact inverse, we
should cancel it out.
@llvmbot
Copy link
Member

llvmbot commented Oct 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

Changes

If an affine.apply uses every result of an
affine.delinearize_index operaration in an expresession of the form x_0 * S_0 + x_1 * S_1 + ... + x_n * S_n + ..., where S_i is the "stride" of the i-th delinerization result (the value it got divided by), then, that chain of additions contains the inverse of the affine.delinearize_index.

We don't want to compose affine.delinearize_index into affine.apply in general, since this leads to "simplifications" (mainly the x % y => x - (x / y) * y rewrite) thate are bad for code generation and algetbraic reasoning. However, if we do see an exact inverse, we should cancel it out.


Full diff: https://github.com/llvm/llvm-project/pull/163440.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+141)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+130)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7e5ce26b5f733..291b729e07aef 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1125,6 +1125,142 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
   return success(*map != initialMap);
 }
 
+/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
+/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
+/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove)
+/// into `replacementsMap`. If no entries were added to `replacementsMap`,
+/// nothing was found.
+static void shortenAddChainsContainingAll(
+    AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
+    AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
+  auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
+  if (!binOp)
+    return;
+  AffineExpr lhs = binOp.getLHS();
+  AffineExpr rhs = binOp.getRHS();
+  if (binOp.getKind() != AffineExprKind::Add) {
+    shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
+    shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
+    return;
+  }
+  SmallVector<AffineExpr> toPreserve;
+  llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
+  AffineExpr thisTerm = rhs;
+  AffineExpr nextTerm = lhs;
+
+  while (thisTerm) {
+    if (!ourTracker.erase(thisTerm)) {
+      toPreserve.push_back(thisTerm);
+      shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
+                                    replacementsMap);
+    }
+    auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
+    if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
+      thisTerm = nextTerm;
+      nextTerm = AffineExpr();
+    } else {
+      thisTerm = nextBinOp.getRHS();
+      nextTerm = nextBinOp.getLHS();
+    }
+  }
+  if (!ourTracker.empty())
+    return;
+  // We reverse the terms to be preserved here in order to preserve
+  // associativity between them.
+  AffineExpr newExpr = newVal;
+  for (AffineExpr preserved : llvm::reverse(toPreserve))
+    newExpr = newExpr + preserved;
+  replacementsMap.insert({e, newExpr});
+}
+
+/// If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N +
+/// ...` (not necessarily in order) where the set of the `x_i` is the set of
+/// outputs of an `affine.delinearize_index` whos inverse is that expression,
+/// replace that expression with the input of that delinearize_index op.
+///
+/// `unitDimInput` is the input that was detected as the potential start to this
+/// replacement chain - if it isn't the rightmost result of the delinearization,
+/// this method fails. (This is intended to ensure we don't have redundant scans
+/// over the same expression).
+///
+/// While this currently only handles delinearizations with a constant basis,
+/// that isn't a fundamental limitation.
+///
+/// This is a utility function for `replaceDimOrSym` below.
+static LogicalResult replaceAffineDelinearizeIndexInverseExpression(
+    AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
+    SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) {
+  if (!delinOp.getDynamicBasis().empty())
+    return failure();
+  if (resultToReplace != delinOp.getMultiIndex().back())
+    return failure();
+
+  MLIRContext *ctx = delinOp.getContext();
+  SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
+  for (auto [pos, dim] : llvm::enumerate(dims)) {
+    auto asResult = dyn_cast_if_present<OpResult>(dim);
+    if (!asResult)
+      continue;
+    if (asResult.getOwner() == delinOp.getOperation())
+      resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
+  }
+  for (auto [pos, sym] : llvm::enumerate(syms)) {
+    auto asResult = dyn_cast_if_present<OpResult>(sym);
+    if (!asResult)
+      continue;
+    if (asResult.getOwner() == delinOp.getOperation())
+      resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
+  }
+  if (llvm::any_of(resToExpr, [](auto e) { return e == AffineExpr(); })) {
+    return failure();
+  }
+
+  bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
+  int64_t stride = 1;
+  llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
+  // This isn't zip_equal since sometimes the delinearize basis is missing a
+  // size for the first result.
+  for (auto [binding, size] : llvm::zip(
+           llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
+    expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
+    stride *= size;
+  }
+  if (resToExpr.size() != delinOp.getStaticBasis().size())
+    expectedExprs.insert(resToExpr[0] * stride);
+
+  DenseMap<AffineExpr, AffineExpr> replacements;
+  AffineExpr delinInExpr = isDimReplacement
+                               ? getAffineDimExpr(dims.size(), ctx)
+                               : getAffineSymbolExpr(syms.size(), ctx);
+
+  for (AffineExpr e : map->getResults())
+    shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
+  if (replacements.empty())
+    return failure();
+
+  AffineMap origMap = *map;
+  if (isDimReplacement)
+    dims.push_back(delinOp.getLinearIndex());
+  else
+    syms.push_back(delinOp.getLinearIndex());
+  *map = origMap.replace(replacements, dims.size(), syms.size());
+
+  // Blank out dead dimensions and symbols
+  for (AffineExpr e : resToExpr) {
+    if (auto d = dyn_cast<AffineDimExpr>(e)) {
+      unsigned pos = d.getPosition();
+      if (!map->isFunctionOfDim(pos))
+        dims[pos] = nullptr;
+    }
+    if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
+      unsigned pos = s.getPosition();
+      if (!map->isFunctionOfSymbol(pos))
+        syms[pos] = nullptr;
+    }
+  }
+  return success();
+}
+
 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
 /// defining AffineApplyOp expression and operands.
 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1157,6 +1293,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
                                                  syms);
   }
 
+  if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
+    return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
+                                                          syms);
+  }
+
   auto affineApply = v.getDefiningOp<AffineApplyOp>();
   if (!affineApply)
     return failure();
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index e56079c1cccd4..1169cd1c29d74 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
 
 // -----
 
+// CHECK-LABEL: func @delin_apply_cancel_exact
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]]
+// CHECK-NOT: memref.store
+// CHECK: return
+func.func @delin_apply_cancel_exact(%arg0:  index, %arg1: memref<?xindex>) {
+  %a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index
+  %b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
+  %c:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+  %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0]
+  memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+  %t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0]
+  memref.store %t2, %arg1[%t2] : memref<?xindex>
+
+  %t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1]
+  memref.store %t3, %arg1[%t3] : memref<?xindex>
+
+  %t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0]
+  memref.store %t4, %arg1[%t4] : memref<?xindex>
+
+  %t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0]
+  memref.store %t5, %arg1[%t5] : memref<?xindex>
+
+  %t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0]
+  memref.store %t6, %arg1[%t5] : memref<?xindex>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @delin_apply_cancel_exact_dim
+// CHECK: affine.for %[[arg1:.+]] = 0 to 256
+// CHECK: memref.store %[[arg1]]
+// CHECK: return
+func.func @delin_apply_cancel_exact_dim(%arg0: memref<?xindex>) {
+  affine.for %arg1 = 0 to 256 {
+    %a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index
+    %i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1)
+    memref.store %i, %arg0[%i] : memref<?xindex>
+  }
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)>
+// CHECK-LABEL: func @delin_apply_cancel_const_term
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_const_term(%arg0:  index, %arg1: memref<?xindex>) {
+  %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
+
+  %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1]
+  memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)>
+// CHECK-LABEL: func @delin_apply_cancel_var_term
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>, %[[ARG2:.+]]: index)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_var_term(%arg0:  index, %arg1: memref<?xindex>, %arg2: index) {
+  %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
+
+  %t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2]
+  memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)>
+// CHECK-LABEL: func @delin_apply_cancel_nested_exprs
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_nested_exprs(%arg0:  index, %arg1: memref<?xindex>) {
+  %a:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+  %t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0]
+  memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20)
+// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_preserve_rotation(%arg0:  index, %arg1: memref<?xindex>) {
+  %a:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+  %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20 + s0)>()[%a#1, %a#0]
+  memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)>
+// CHECK-LABEL: func @delin_apply_dont_cancel_partial
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5)
+// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1]
+// CHECK: return
+func.func @delin_apply_dont_cancel_partial(%arg0:  index, %arg1: memref<?xindex>) {
+  %a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
+
+  %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1]
+  memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
 // CHECK-SAME:    (%[[ARG0:.*]]: index)
 // CHECK:         %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@krzysz00
Copy link
Contributor Author

(For the record, IREE benchmark performance was ~flat or the sort of 1.01x that might also be noise)

@krzysz00 krzysz00 merged commit 8c05b5c into llvm:main Oct 16, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants