Commit e20db31

[mlir][tosa] Fold PadOp to tensor operations
Add a canonicalizer to enable folding of explicit padding operations to
implicit padding attributes of tensor operations. This enables folding to
the following operations:
- Conv2d
- DepthwiseConv2d
- AvgPool2d
- MaxPool2d

Signed-off-by: Georgios Pinitas <[email protected]>
Co-authored-by: Rob-Hughes-Arm <[email protected]>
1 parent 728320f commit e20db31
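
For context, here is a schematic before/after sketch of the rewrite this enables. It is illustrative only: shapes are arbitrary, the zero-point and pad_const operand definitions and several attributes are elided with "...", and the exact assembly syntax depends on the TOSA dialect revision. It assumes the tosa.pad adds one element of padding on each side of H and W and none on N or C:

// Before: an explicit tosa.pad feeding a conv2d that has no implicit padding.
%padded = tosa.pad %input, %padding, %pad_const : (...) -> tensor<1x10x10x3xf32>
%conv = tosa.conv2d %padded, %weight, %bias, ... {pad = array<i64: 0, 0, 0, 0>, ...}

// After canonicalization: the explicit padding is merged into the conv2d `pad`
// attribute ([top, bottom, left, right]); the now-unused tosa.pad becomes dead.
%conv = tosa.conv2d %input, %weight, %bias, ... {pad = array<i64: 1, 1, 1, 1>, ...}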

File tree

3 files changed: +428, -35 lines


mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 0 deletions
@@ -107,6 +107,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     LogicalResult verifyOutputZeroPoint(int64_t zp);
   }];
 
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -153,6 +154,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -244,6 +247,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
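
Setting `let hasCanonicalizer = 1;` makes ODS declare a `static void getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context);` hook on each of these generated op classes; the corresponding definitions are added in TosaCanonicalizations.cpp below (see also the usage sketch after that diff).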

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 271 additions & 35 deletions
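
The additions below fall into three parts: a shared helper that checks the pad constant against the operator zero point, per-op adaptors (PoolPadFoldAdaptor, ConvPadFoldAdaptor) that encode each op's legality checks and rebuild logic, and a generic FoldPadToTensorOp pattern that is registered through each op's getCanonicalizationPatterns hook.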
@@ -39,6 +39,277 @@ using namespace mlir::tosa;
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+// Check that the zero point of the tensor and padding operations are aligned.
+bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+  // Check that padConst is a constant value and a scalar tensor
+  DenseElementsAttr padConstAttr;
+  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+      (padConstAttr.size() != 1)) {
+    return false;
+  }
+
+  // Check that floating point pad is zero
+  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+    return padConstVal == 0.0f;
+  }
+
+  // Check that the zp and padConst align for the integer (quantized) case
+  if (auto padConstIntAttr =
+          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+    DenseIntElementsAttr zpAttr;
+    // Check that zp is a constant value and a scalar tensor
+    if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
+      return false;
+    }
+
+    // Check equality
+    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
+    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+    return zpVal == padConstVal;
+  }
+
+  // Bail-out on unsupported type
+  return false;
+}
+
+namespace {
+template <typename OpTy>
+struct PoolPadFoldAdaptor;
+
+template <>
+struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+  using OpTy = tosa::AvgPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+        op.getAccType());
+  }
+};
+
+template <>
+struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+  using OpTy = tosa::MaxPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy, Value padConst) {
+    // Check that padConst is a constant value and a scalar tensor
+    DenseElementsAttr padConstAttr;
+    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+        padConstAttr.size() != 1) {
+      return false;
+    }
+
+    // Pad needs to be in the minimum value to be able to merge
+    if (auto padConstFpAttr =
+            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+      const APFloat padConstVal = *padConstFpAttr.begin();
+      const APFloat lowestVal =
+          APFloat::getLargest(padConstVal.getSemantics(), true);
+      return padConstVal == lowestVal;
+    } else if (auto padConstIntAttr =
+                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+      const APInt padConstVal = *padConstIntAttr.begin();
+      const APInt lowestVal =
+          APInt::getSignedMinValue(padConstVal.getBitWidth());
+      return padConstVal == lowestVal;
+    }
+
+    // Bail-out on unsupported type
+    return false;
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+        op, op.getType(), padInput, op.getKernel(), op.getStride(),
+        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+  }
+};
+
+template <typename OpTy>
+struct ConvPadFoldAdaptor {
+  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<OpTy>(
+        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+  }
+};
+
+// Pattern attempts to fold a `tosa.pad` operator to a following tensor
+// operation like `tosa.conv2d` by merging the padding associated with the
+// pad operator directly to the implicit padding of the tensor operation.
+// This helps eliminate the explicit padding operator if unused.
+template <typename OpTy, typename AdaptorTy>
+struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy tensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Check producer is a tosa::PadOp
+    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Producer must be a tosa::PadOp.");
+
+    // Validate that tensor operation has sane padding
+    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Tensor operation padding shall have 4 elements.");
+
+    // Validate tosa::PadOp padding
+    DenseIntElementsAttr padOpPadding;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+    // C_after
+    if (padOpPadding.size() != 8)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Pad padding should have 8 elements.");
+    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+    // Fold padding from Pad into the tensor operation
+    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
+    SmallVector<int64_t> foldedPad(tensorOpPad.size());
+    foldedPad[0] = padHBefore + tensorOpPad[0];
+    foldedPad[1] = padHAfter + tensorOpPad[1];
+    foldedPad[2] = padWBefore + tensorOpPad[2];
+    foldedPad[3] = padWAfter + tensorOpPad[3];
+
+    // Check kernel related restrictions
+    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size not aligned with kernel restrictions.");
+    }
+
+    // Check padding constant restrictions
+    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "Padding constant is not aligned with operator zero-point.");
+    }
+
+    // Check that padding doesn't grow more than 8K level (8192) for now
+    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size more than the 8K level limit.");
+    }
+
+    // Create operator
+    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+                                   foldedPad);
+
+    return success();
+  }
+};
+} // namespace
+
+void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+      context);
+}
+
+void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<
+      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+      context);
+}
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
+                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+      context);
+}
+
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+    Value output = op.getOutput();
+    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // If the output and input shapes are 1x1, then this is a no op.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp,
+              FoldPadToTensorOp<tosa::MaxPool2dOp,
+                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+      context);
+}
+
+//===----------------------------------------------------------------------===//
+// Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
@@ -175,41 +446,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value output = op.getOutput();
-    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
-    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern::OpRewritePattern;
 
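In normal use, these patterns are picked up by the standard canonicalization pass (for example `mlir-opt --canonicalize`). For applying just these rewrites programmatically, the following is a minimal, hypothetical sketch, not part of this commit: `foldTosaPads` and `funcOp` are illustrative names, and the greedy-driver entry point is spelled `applyPatternsGreedily` in recent MLIR (older trees use `applyPatternsAndFoldGreedily`).

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: run only the TOSA pad-folding canonicalizations on a
// function body.
static mlir::LogicalResult foldTosaPads(mlir::func::FuncOp funcOp) {
  mlir::MLIRContext *ctx = funcOp.getContext();
  mlir::RewritePatternSet patterns(ctx);
  // These hooks exist because the ops declare `hasCanonicalizer = 1` in ODS.
  mlir::tosa::Conv2DOp::getCanonicalizationPatterns(patterns, ctx);
  mlir::tosa::DepthwiseConv2DOp::getCanonicalizationPatterns(patterns, ctx);
  mlir::tosa::AvgPool2dOp::getCanonicalizationPatterns(patterns, ctx);
  mlir::tosa::MaxPool2dOp::getCanonicalizationPatterns(patterns, ctx);
  // Apply the patterns until fixpoint; tosa.pad ops left without users become
  // trivially dead and are removed by the driver.
  return mlir::applyPatternsGreedily(funcOp, std::move(patterns));
}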