Skip to content

Commit 0820266

Browse files
authored
[mlir] Use llvm accumulate wrappers. NFCI. (#162957)
Use wrappers around `std::accumulate` to make the code more concise and less bug-prone: #162129. With `std::accumulate`, it's the initial value that determines the accumulator type. `llvm::sum_of` and `llvm::product_of` pick the right accumulator type based on the range element type. Found some funny bugs like a local accumulate helper that calculated a sum with initial value of 1 -- we didn't hit the bug because the code was actually dead...
1 parent 7eee672 commit 0820266

File tree

32 files changed

+60
-114
lines changed

32 files changed

+60
-114
lines changed

mlir/examples/toy/Ch2/mlir/MLIRGen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,7 @@ class MLIRGenImpl {
264264
// The attribute is a vector with a floating point value per element
265265
// (number) in the array, see `collectData()` below for more details.
266266
std::vector<double> data;
267-
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
268-
std::multiplies<int>()));
267+
data.reserve(llvm::product_of(lit.getDims()));
269268
collectData(lit, data);
270269

271270
// The type of this attribute is tensor of 64-bit floating-point with the

mlir/examples/toy/Ch3/mlir/MLIRGen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,7 @@ class MLIRGenImpl {
264264
// The attribute is a vector with a floating point value per element
265265
// (number) in the array, see `collectData()` below for more details.
266266
std::vector<double> data;
267-
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
268-
std::multiplies<int>()));
267+
data.reserve(llvm::product_of(lit.getDims()));
269268
collectData(lit, data);
270269

271270
// The type of this attribute is tensor of 64-bit floating-point with the

mlir/examples/toy/Ch4/mlir/MLIRGen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ class MLIRGenImpl {
268268
// The attribute is a vector with a floating point value per element
269269
// (number) in the array, see `collectData()` below for more details.
270270
std::vector<double> data;
271-
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
272-
std::multiplies<int>()));
271+
data.reserve(llvm::product_of(lit.getDims()));
273272
collectData(lit, data);
274273

275274
// The type of this attribute is tensor of 64-bit floating-point with the

mlir/examples/toy/Ch5/mlir/MLIRGen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ class MLIRGenImpl {
268268
// The attribute is a vector with a floating point value per element
269269
// (number) in the array, see `collectData()` below for more details.
270270
std::vector<double> data;
271-
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
272-
std::multiplies<int>()));
271+
data.reserve(llvm::product_of(lit.getDims()));
273272
collectData(lit, data);
274273

275274
// The type of this attribute is tensor of 64-bit floating-point with the

mlir/examples/toy/Ch6/mlir/MLIRGen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ class MLIRGenImpl {
268268
// The attribute is a vector with a floating point value per element
269269
// (number) in the array, see `collectData()` below for more details.
270270
std::vector<double> data;
271-
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
272-
std::multiplies<int>()));
271+
data.reserve(llvm::product_of(lit.getDims()));
273272
collectData(lit, data);
274273

275274
// The type of this attribute is tensor of 64-bit floating-point with the

mlir/examples/toy/Ch7/mlir/MLIRGen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,8 +405,7 @@ class MLIRGenImpl {
405405
// The attribute is a vector with a floating point value per element
406406
// (number) in the array, see `collectData()` below for more details.
407407
std::vector<double> data;
408-
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
409-
std::multiplies<int>()));
408+
data.reserve(llvm::product_of(lit.getDims()));
410409
collectData(lit, data);
411410

412411
// The type of this attribute is tensor of 64-bit floating-point with the

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/TypeRange.h"
2323
#include "mlir/IR/Value.h"
2424
#include "mlir/Transforms/DialectConversion.h"
25+
#include "llvm/ADT/STLExtras.h"
2526
#include <cstdint>
2627
#include <numeric>
2728

@@ -110,9 +111,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
110111
{TypeAttr::get(memrefType.getElementType())}));
111112

112113
IndexType indexType = builder.getIndexType();
113-
int64_t numElements = std::accumulate(memrefType.getShape().begin(),
114-
memrefType.getShape().end(), int64_t{1},
115-
std::multiplies<int64_t>());
114+
int64_t numElements = llvm::product_of(memrefType.getShape());
116115
emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
117116
builder, loc, indexType, builder.getIndexAttr(numElements));
118117

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1919
#include "mlir/IR/PatternMatch.h"
2020
#include "mlir/Transforms/DialectConversion.h"
21+
#include "llvm/ADT/STLExtras.h"
2122

2223
#include <numeric>
2324

@@ -70,8 +71,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
7071

7172
// Calculate the product of all elements in 'newShape' except for the -1
7273
// placeholder, which we discard by negating the result.
73-
int64_t totalSizeNoPlaceholder = -std::accumulate(
74-
newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
74+
int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape);
7575

7676
// If there is a 0 component in 'newShape', resolve the placeholder as
7777
// 0.

mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Pass/Pass.h"
2121
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2222

23+
#include "llvm/ADT/STLExtras.h"
2324
#include "llvm/Support/DebugLog.h"
2425

2526
#include <numeric>
@@ -265,8 +266,7 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
265266
if (isPacked)
266267
src = collapseLastDim(rewriter, src);
267268
int64_t rows = vecShape[0];
268-
int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
269-
std::multiplies<int64_t>());
269+
int64_t cols = llvm::product_of(vecShape.drop_front());
270270
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
271271

272272
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
@@ -336,8 +336,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
336336

337337
ArrayRef<int64_t> shape = vecTy.getShape();
338338
int64_t rows = shape[0];
339-
int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
340-
std::multiplies<int64_t>());
339+
int64_t cols = llvm::product_of(shape.drop_front());
341340
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
342341

343342
return amx::TileLoadOp::create(rewriter, loc, tileType, buf,

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/IR/Builders.h"
2727
#include "mlir/Pass/Pass.h"
2828
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29+
#include "llvm/ADT/STLExtras.h"
2930

3031
namespace mlir {
3132
#define GEN_PASS_DEF_CONVERTVECTORTOSCF
@@ -760,8 +761,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
760761
if (vectorType.getRank() != 1) {
761762
// Flatten n-D vectors to 1D. This is done to allow indexing with a
762763
// non-constant value.
763-
auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
764-
std::multiplies<int64_t>());
764+
int64_t flatLength = llvm::product_of(shape);
765765
auto flatVectorType =
766766
VectorType::get({flatLength}, vectorType.getElementType());
767767
value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);

0 commit comments

Comments
 (0)