Skip to content

Commit c8b1c57

Browse files
Use macro in specialize.cpp pass
1 parent c39f831 commit c8b1c57

File tree

2 files changed

+15
-32
lines changed

2 files changed

+15
-32
lines changed

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -268,44 +268,25 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
268268
static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
269269
GenericOp genericOp) {
270270
SmallVector<int64_t> dilations, strides;
271+
#define CONV_OP_SPECIALIZER(ConvOpTy) \
272+
if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \
273+
return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
274+
strides); \
271275
// -----------------------------
272276
// Depthwise Convolution ops.
273277
// -----------------------------
274-
if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
275-
genericOp, &dilations, &strides))
276-
return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
277-
rewriter, genericOp, dilations, strides);
278-
if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
279-
genericOp, &dilations, &strides))
280-
return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
281-
rewriter, genericOp, dilations, strides);
282-
if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
283-
genericOp, &dilations, &strides))
284-
return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
285-
rewriter, genericOp, dilations, strides);
278+
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
279+
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
280+
CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
286281
// -----------------------------
287282
// Pooling ops.
288283
// -----------------------------
289-
if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
290-
&strides))
291-
return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
292-
dilations, strides);
293-
if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
294-
&strides))
295-
return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
296-
dilations, strides);
297-
if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
298-
&strides))
299-
return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
300-
dilations, strides);
301-
if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
302-
genericOp, &dilations, &strides))
303-
return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
304-
rewriter, genericOp, dilations, strides);
305-
if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
306-
genericOp, &dilations, &strides))
307-
return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
308-
rewriter, genericOp, dilations, strides);
284+
CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
285+
CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
286+
CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
287+
CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
288+
CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
289+
#undef CONV_OP_SPECIALIZER
309290
return failure();
310291
}
311292

mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ func.func @pooling_nhwc_min_unsigned_integer(%input: tensor<?x?x?x?xi32>, %filte
112112
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
113113
// CHECK-NOT: linalg.generic
114114

115+
// -----
116+
115117
func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
116118
%0 = linalg.pooling_nhwc_min_unsigned
117119
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}

0 commit comments

Comments
 (0)