|
20 | 20 | #include "mlir/Pass/Pass.h" |
21 | 21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
22 | 22 |
|
| 23 | +#include "llvm/ADT/STLExtras.h" |
23 | 24 | #include "llvm/Support/DebugLog.h" |
24 | 25 |
|
25 | 26 | #include <numeric> |
@@ -265,8 +266,7 @@ loadStoreFromTransfer(PatternRewriter &rewriter, |
265 | 266 | if (isPacked) |
266 | 267 | src = collapseLastDim(rewriter, src); |
267 | 268 | int64_t rows = vecShape[0]; |
268 | | - int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1, |
269 | | - std::multiplies<int64_t>()); |
| 269 | + int64_t cols = llvm::product_of(vecShape.drop_front()); |
270 | 270 | auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); |
271 | 271 |
|
272 | 272 | Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); |
@@ -336,8 +336,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter, |
336 | 336 |
|
337 | 337 | ArrayRef<int64_t> shape = vecTy.getShape(); |
338 | 338 | int64_t rows = shape[0]; |
339 | | - int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1, |
340 | | - std::multiplies<int64_t>()); |
| 339 | + int64_t cols = llvm::product_of(shape.drop_front()); |
341 | 340 | auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); |
342 | 341 |
|
343 | 342 | return amx::TileLoadOp::create(rewriter, loc, tileType, buf, |
|
0 commit comments