Commit ffc7db9

[mlir][tosa] Fold PadOp to tensor operations
Add a canonicalizer to enable folding of explicit padding operations to implicit padding attributes of tensor operations. This enables folding to the following operations:
- Conv2d
- DepthwiseConv2d
- AvgPool2d
- MaxPool2d

Signed-off-by: Georgios Pinitas <[email protected]>
Co-authored-by: Rob-Hughes-Arm <[email protected]>
1 parent 728320f commit ffc7db9
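
Because the change is wired in through the ops' `hasCanonicalizer` hooks, the new folds are picked up by the generic canonicalization infrastructure rather than a dedicated pass. The sketch below is not part of the patch; the helper name `foldExplicitTosaPads` and the pipeline setup are illustrative only, showing one way the patterns could be exercised programmatically on an already-built module.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Illustrative helper (not in the patch): run the canonicalizer over a module
// so that the FoldPadToTensorOp patterns registered through
// getCanonicalizationPatterns() — enabled by `hasCanonicalizer = 1` in
// TosaOps.td — get a chance to fire.
static mlir::LogicalResult foldExplicitTosaPads(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);
}

Equivalently, running `-canonicalize` in an mlir-opt pipeline exercises the same patterns, which is how rewrites like this are typically checked in-tree.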

3 files changed (+424 lines, -35 lines)

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 0 deletions
@@ -107,6 +107,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     LogicalResult verifyOutputZeroPoint(int64_t zp);
   }];
 
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -153,6 +154,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -244,6 +247,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 267 additions & 35 deletions
@@ -39,6 +39,273 @@ using namespace mlir::tosa;
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+// Check that the zero point of the tensor and padding operations are aligned.
+bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+  // Check that padConst is a constant value and a scalar tensor
+  DenseElementsAttr padConstAttr;
+  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+      (padConstAttr.size() != 1)) {
+    return false;
+  }
+
+  // Check that floating point pad is zero
+  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+    return padConstVal == 0.0f;
+  }
+
+  // Check that the zp and padConst align for the integer (quantized) case
+  if (auto padConstIntAttr =
+          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+    DenseIntElementsAttr zpAttr;
+    // Check that zp is a constant value and a scalar tensor
+    if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
+      return false;
+    }
+
+    // Check equality
+    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
+    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+    return zpVal == padConstVal;
+  }
+
+  // Bail-out on unsupported type
+  return false;
+}
+
+namespace {
+template <typename OpTy>
+struct PoolPadFoldAdaptor;
+
+template <>
+struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+  using OpTy = tosa::AvgPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+        op.getAccType());
+  }
+};
+
+template <>
+struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+  using OpTy = tosa::MaxPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy, Value padConst) {
+    // Check that padConst is a constant value and a scalar tensor
+    DenseElementsAttr padConstAttr;
+    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+        padConstAttr.size() != 1) {
+      return false;
+    }
+
+    // Pad needs to be in the minimum value to be able to merge
+    if (auto padConstFpAttr =
+            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+      float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+      return padConstVal == std::numeric_limits<float>::lowest();
+    } else if (auto padConstIntAttr =
+                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+      int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+      return padConstVal == std::numeric_limits<int32_t>::lowest();
+    }
+
+    // Bail-out on unsupported type
+    return false;
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+        op, op.getType(), padInput, op.getKernel(), op.getStride(),
+        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+  }
+};
+
+template <typename OpTy>
+struct ConvPadFoldAdaptor {
+  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<OpTy>(
+        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+  }
+};
+
+// Pattern attempts to fold a `tosa.pad` operator to a following tensor
+// operation like `tosa.conv2d` by merging the padding associated with the
+// pad operator directly to the implicit padding of the tensor operation.
+// This helps eliminate the explicit padding operator if unused.
+template <typename OpTy, typename AdaptorTy>
+struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy tensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Check producer is a tosa::PadOp
+    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Producer must be a tosa::PadOp.");
+
+    // Validate that tensor operation has sane padding
+    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Tensor operation padding shall have 4 elements.");
+
+    // Validate tosa::PadOp padding
+    DenseIntElementsAttr padOpPadding;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+    // C_after
+    if (padOpPadding.size() != 8)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Pad padding should have 8 elements.");
+    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+    // Fold padding from Pad into the tensor operation
+    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
+    SmallVector<int64_t> foldedPad(tensorOpPad.size());
+    foldedPad[0] = padHBefore + tensorOpPad[0];
+    foldedPad[1] = padHAfter + tensorOpPad[1];
+    foldedPad[2] = padWBefore + tensorOpPad[2];
+    foldedPad[3] = padWAfter + tensorOpPad[3];
+
+    // Check kernel related restrictions
+    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size not aligned with kernel restrictions.");
+    }
+
+    // Check padding constant restrictions
+    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "Padding constant is not aligned with operator zero-point.");
+    }
+
+    // Check that padding doesn't grow more than 8K level (8192) for now
+    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size more than the 8K level limit.");
+    }
+
+    // Create operator
+    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+                                   foldedPad);
+
+    return success();
+  }
+};
+} // namespace
+
+void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+      context);
+}
+
+void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<
+      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+      context);
+}
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
+                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+      context);
+}
+
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+    Value output = op.getOutput();
+    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // If the output and input shapes are 1x1, then this is a no op.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp,
+              FoldPadToTensorOp<tosa::MaxPool2dOp,
+                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+      context);
+}
+
+//===----------------------------------------------------------------------===//
+// Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
@@ -175,41 +442,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value output = op.getOutput();
-    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
-    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern::OpRewritePattern;
 
