Commit ac25641

[mlir][tosa] Fold PadOp to tensor operations
Add a canonicalizer that folds an explicit tosa.pad operation into the implicit padding attribute of the tensor operation that consumes it. This enables folding into the following operations:

- Conv2d
- DepthwiseConv2d
- AvgPool2d
- MaxPool2d

Signed-off-by: Georgios Pinitas <[email protected]>
Co-authored-by: Rob-Hughes-Arm <[email protected]>
1 parent ec9546d · commit ac25641
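In effect, the rewrite adds the pad amounts of a producing tosa.pad to the consumer's pad attribute. A minimal before/after sketch (operands and types elided with `...`; the numbers mirror the @pad_wh_conv2d_fold test added below):

  // Before: explicit pad (H: 1,1; W: 2,2) feeding a conv2d with pad = [1, 1, 1, 1]
  %padded = tosa.pad %input, %pad_shape, %pad_const : (...) -> ...
  %conv = tosa.conv2d %padded, ... {pad = array<i64: 1, 1, 1, 1>, ...}

  // After: the explicit pad is absorbed into the conv2d's implicit padding
  %conv = tosa.conv2d %input, ... {pad = array<i64: 2, 2, 3, 3>, ...}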

File tree: 3 files changed, +253 −35 lines

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 0 deletions
@@ -107,6 +107,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     LogicalResult verifyOutputZeroPoint(int64_t zp);
   }];
 
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -153,6 +154,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -244,6 +247,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
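For context, `let hasCanonicalizer = 1;` asks ODS to declare a canonicalization hook on each generated op class; the definitions are supplied in TosaCanonicalizations.cpp below. A sketch of the declared member, with namespaces abbreviated:

  // Declared on the op class when hasCanonicalizer is set:
  static void getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context);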

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 169 additions & 35 deletions
@@ -39,6 +39,175 @@ using namespace mlir::tosa;
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+namespace {
+template <typename OpTy>
+struct PoolPadFoldAdaptor;
+
+template <>
+struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+  static void replaceOpWithNewPad(PatternRewriter &rewriter,
+                                  tosa::AvgPool2dOp op, Value padInput,
+                                  ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+        op.getAccType());
+  }
+};
+
+template <>
+struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+  static void replaceOpWithNewPad(PatternRewriter &rewriter,
+                                  tosa::MaxPool2dOp op, Value padInput,
+                                  ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+        op, op.getType(), padInput, op.getKernel(), op.getStride(),
+        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+  }
+};
+
+template <typename OpTy>
+struct ConvPadFoldAdaptor {
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<OpTy>(
+        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+  }
+};
+
+// Pattern attempts to fold a `tosa.pad` operator to a following tensor
+// operation like `tosa.conv2d` by merging the padding associated with the
+// pad operator directly to the implicit padding of the tensor operation.
+// This helps eliminate the explicit padding operator if unused.
+template <typename OpTy, typename AdaptorTy>
+struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy tensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Check producer is a tosa::PadOp
+    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Producer must be a tosa::PadOp.");
+
+    // Validate that tensor operation has sane padding
+    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Tensor operation padding shall have 4 elements.");
+
+    // Validate tosa::PadOp padding
+    DenseIntElementsAttr padOpPadding;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+    // C_after
+    if (padOpPadding.size() != 8)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Pad padding should have 8 elements.");
+    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+    // Fold padding from Pad into the tensor operation
+    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
+    SmallVector<int64_t> foldedPad(tensorOpPad.size());
+    foldedPad[0] = padHBefore + tensorOpPad[0];
+    foldedPad[1] = padHAfter + tensorOpPad[1];
+    foldedPad[2] = padWBefore + tensorOpPad[2];
+    foldedPad[3] = padWAfter + tensorOpPad[3];
+
+    // Replace operator
+    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+                                   foldedPad);
+
+    return success();
+  }
+};
+} // namespace
+
+void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+      context);
+}
+
+void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<
+      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+      context);
+}
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
+                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+      context);
+}
+
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+    Value output = op.getOutput();
+    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // If the output and input shapes are 1x1, then this is a no op.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp,
+              FoldPadToTensorOp<tosa::MaxPool2dOp,
+                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+      context);
+}
+
+//===----------------------------------------------------------------------===//
+// Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
@@ -175,41 +344,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value output = op.getOutput();
-    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
-    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern::OpRewritePattern;
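The fold itself is a per-edge sum of the two paddings. A worked instance, taken from the @pad_wh_avg_pool2d_fold test added below:

  // pad op contributes:  N = [0, 0], H = [1, 0], W = [1, 0], C = [0, 0]
  // pool's own pad:      [top, bottom, left, right] = [0, 1, 0, 1]
  // folded attribute:    [1 + 0, 0 + 1, 1 + 0, 0 + 1] = [1, 1, 1, 1]

The per-op adaptor structs exist because each op rebuilds with a different operand/attribute list; the shared FoldPadToTensorOp pattern only validates and sums the padding, then delegates reconstruction to the adaptor.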

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 79 additions & 0 deletions
@@ -9,6 +9,85 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
 
 // -----
 
+// CHECK-LABEL: @pad_wh_avg_pool2d_fold
+func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+  // CHECK-NOT: tosa.pad
+  // CHECK: tosa.avg_pool2d
+  // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x9x3xf32>
+  %pool = tosa.avg_pool2d %padded, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x6x5x3xf32>
+  return %pool : tensor<1x6x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_conv2d_fold
+func.func @pad_wh_conv2d_fold(%input: tensor<1x8x4x3xf32>, %weight: tensor<1x3x3x3xf32>, %bias: tensor<1xf32>) -> tensor<1x10x8x1xf32> {
+  // CHECK-NOT: tosa.pad
+  // CHECK: tosa.conv2d
+  // CHECK-SAME: pad = array<i64: 2, 2, 3, 3>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+  %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x1xf32>
+  return %conv : tensor<1x10x8x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_bwh_conv2d_nofold
+func.func @pad_bwh_conv2d_nofold(%input: tensor<1x8x4x3xf32>, %weight: tensor<1x3x3x3xf32>, %bias: tensor<1xf32>) -> tensor<3x10x8x1xf32> {
+  // CHECK: tosa.pad
+  // CHECK: tosa.conv2d
+  // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+  %pad_shape = tosa.const_shape { values = dense<[1, 1, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<3x10x8x3xf32>
+  %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<3x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3x10x8x1xf32>
+  return %conv : tensor<3x10x8x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_depthwise_conv2d_fold
+func.func @pad_wh_depthwise_conv2d_fold(%input: tensor<1x8x4x3xf32>, %weight: tensor<3x3x3x1xf32>, %bias: tensor<3xf32>) -> tensor<1x10x8x3xf32> {
+  // CHECK-NOT: tosa.pad
+  // CHECK: tosa.depthwise_conv2d
+  // CHECK-SAME: pad = array<i64: 2, 2, 3, 3>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+  %conv = tosa.depthwise_conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<3x3x3x1xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+  return %conv : tensor<1x10x8x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_max_pool2d_fold
+func.func @pad_wh_max_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+  // CHECK-NOT: tosa.pad
+  // CHECK: tosa.max_pool2d
+  // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x9x3xf32>
+  %pool = tosa.max_pool2d %padded {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xf32>) -> tensor<1x6x5x3xf32>
+  return %pool : tensor<1x6x5x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @add_bcast_zero_int
 func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
   // CHECK-NOT: tosa.add
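As a usage sketch (assuming a standard mlir-opt build; the file's actual RUN line is outside this hunk and drives these checks under lit), running the canonicalizer directly reproduces the folds above:

  mlir-opt --canonicalize mlir/test/Dialect/Tosa/canonicalize.mlir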
