diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 7e5ce26b5f733..4b348355adc17 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, return success(*map != initialMap); } +/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form +/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`, +/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove) +/// into `replacementsMap`. If no entries were added to `replacementsMap`, +/// nothing was found. +static void shortenAddChainsContainingAll( + AffineExpr e, const llvm::SmallDenseSet &exprsToRemove, + AffineExpr newVal, DenseMap &replacementsMap) { + auto binOp = dyn_cast(e); + if (!binOp) + return; + AffineExpr lhs = binOp.getLHS(); + AffineExpr rhs = binOp.getRHS(); + if (binOp.getKind() != AffineExprKind::Add) { + shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap); + shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap); + return; + } + SmallVector toPreserve; + llvm::SmallDenseSet ourTracker(exprsToRemove); + AffineExpr thisTerm = rhs; + AffineExpr nextTerm = lhs; + + while (thisTerm) { + if (!ourTracker.erase(thisTerm)) { + toPreserve.push_back(thisTerm); + shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal, + replacementsMap); + } + auto nextBinOp = dyn_cast_if_present(nextTerm); + if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) { + thisTerm = nextTerm; + nextTerm = AffineExpr(); + } else { + thisTerm = nextBinOp.getRHS(); + nextTerm = nextBinOp.getLHS(); + } + } + if (!ourTracker.empty()) + return; + // We reverse the terms to be preserved here in order to preserve + // associativity between them. + AffineExpr newExpr = newVal; + for (AffineExpr preserved : llvm::reverse(toPreserve)) + newExpr = newExpr + preserved; + replacementsMap.insert({e, newExpr}); +} + +/// If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N + +/// ...` (not necessarily in order) where the set of the `x_i` is the set of +/// outputs of an `affine.delinearize_index` whos inverse is that expression, +/// replace that expression with the input of that delinearize_index op. +/// +/// `unitDimInput` is the input that was detected as the potential start to this +/// replacement chain - if it isn't the rightmost result of the delinearization, +/// this method fails. (This is intended to ensure we don't have redundant scans +/// over the same expression). +/// +/// While this currently only handles delinearizations with a constant basis, +/// that isn't a fundamental limitation. +/// +/// This is a utility function for `replaceDimOrSym` below. +static LogicalResult replaceAffineDelinearizeIndexInverseExpression( + AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, + SmallVectorImpl &dims, SmallVectorImpl &syms) { + if (!delinOp.getDynamicBasis().empty()) + return failure(); + if (resultToReplace != delinOp.getMultiIndex().back()) + return failure(); + + MLIRContext *ctx = delinOp.getContext(); + SmallVector resToExpr(delinOp.getNumResults(), AffineExpr()); + for (auto [pos, dim] : llvm::enumerate(dims)) { + auto asResult = dyn_cast_if_present(dim); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx); + } + for (auto [pos, sym] : llvm::enumerate(syms)) { + auto asResult = dyn_cast_if_present(sym); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx); + } + if (llvm::is_contained(resToExpr, AffineExpr())) + return failure(); + + bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred); + int64_t stride = 1; + llvm::SmallDenseSet expectedExprs; + // This isn't zip_equal since sometimes the delinearize basis is missing a + // size for the first result. + for (auto [binding, size] : llvm::zip( + llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) { + expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx)); + stride *= size; + } + if (resToExpr.size() != delinOp.getStaticBasis().size()) + expectedExprs.insert(resToExpr[0] * stride); + + DenseMap replacements; + AffineExpr delinInExpr = isDimReplacement + ? getAffineDimExpr(dims.size(), ctx) + : getAffineSymbolExpr(syms.size(), ctx); + + for (AffineExpr e : map->getResults()) + shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements); + if (replacements.empty()) + return failure(); + + AffineMap origMap = *map; + if (isDimReplacement) + dims.push_back(delinOp.getLinearIndex()); + else + syms.push_back(delinOp.getLinearIndex()); + *map = origMap.replace(replacements, dims.size(), syms.size()); + + // Blank out dead dimensions and symbols + for (AffineExpr e : resToExpr) { + if (auto d = dyn_cast(e)) { + unsigned pos = d.getPosition(); + if (!map->isFunctionOfDim(pos)) + dims[pos] = nullptr; + } + if (auto s = dyn_cast(e)) { + unsigned pos = s.getPosition(); + if (!map->isFunctionOfSymbol(pos)) + syms[pos] = nullptr; + } + } + return success(); +} + /// Replace all occurrences of AffineExpr at position `pos` in `map` by the /// defining AffineApplyOp expression and operands. /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced. @@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map, syms); } + if (auto delinOp = v.getDefiningOp()) { + return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims, + syms); + } + auto affineApply = v.getDefiningOp(); if (!affineApply) return failure(); diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index e56079c1cccd4..1169cd1c29d74 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind // ----- +// CHECK-LABEL: func @delin_apply_cancel_exact +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref) +// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]] +// CHECK-NOT: memref.store +// CHECK: return +func.func @delin_apply_cancel_exact(%arg0: index, %arg1: memref) { + %a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index + %b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index + %c:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref + + %t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0] + memref.store %t2, %arg1[%t2] : memref + + %t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1] + memref.store %t3, %arg1[%t3] : memref + + %t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0] + memref.store %t4, %arg1[%t4] : memref + + %t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0] + memref.store %t5, %arg1[%t5] : memref + + %t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0] + memref.store %t6, %arg1[%t5] : memref + + return +} + +// ----- + +// CHECK-LABEL: func @delin_apply_cancel_exact_dim +// CHECK: affine.for %[[arg1:.+]] = 0 to 256 +// CHECK: memref.store %[[arg1]] +// CHECK: return +func.func @delin_apply_cancel_exact_dim(%arg0: memref) { + affine.for %arg1 = 0 to 256 { + %a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index + %i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1) + memref.store %i, %arg0[%i] : memref + } + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)> +// CHECK-LABEL: func @delin_apply_cancel_const_term +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_const_term(%arg0: index, %arg1: memref) { + %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1] + memref.store %t1, %arg1[%t1] : memref + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)> +// CHECK-LABEL: func @delin_apply_cancel_var_term +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref, %[[ARG2:.+]]: index) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_var_term(%arg0: index, %arg1: memref, %arg2: index) { + %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2] + memref.store %t1, %arg1[%t1] : memref + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)> +// CHECK-LABEL: func @delin_apply_cancel_nested_exprs +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_nested_exprs(%arg0: index, %arg1: memref) { + %a:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref) +// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20) +// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_preserve_rotation(%arg0: index, %arg1: memref) { + %a:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20 + s0)>()[%a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)> +// CHECK-LABEL: func @delin_apply_dont_cancel_partial +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref) +// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5) +// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1] +// CHECK: return +func.func @delin_apply_dont_cancel_partial(%arg0: index, %arg1: memref) { + %a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1] + memref.store %t1, %arg1[%t1] : memref + + return +} + +// ----- + // CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index // CHECK-SAME: (%[[ARG0:.*]]: index) // CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index